Skip to content

Commit

Permalink
Merge pull request #585 from stan-dev/expose-diagnostics
Browse files Browse the repository at this point in the history
New method summarizing sampler diagnostics and warnings
  • Loading branch information
jgabry authored Mar 15, 2022
2 parents 17c5e88 + 1c458ab commit ac5a448
Show file tree
Hide file tree
Showing 19 changed files with 561 additions and 117 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ Authors@R:
person(given = "Mikhail", family = "Popov", role = "ctb"),
person(given = "Mike", family = "Lawrence", role = "ctb"),
person(given = c("William", "Michael"), family = "Landau", role = "ctb",
email = "[email protected]", comment = c(ORCID = "0000-0003-1878-3253")))
email = "[email protected]", comment = c(ORCID = "0000-0003-1878-3253")),
person(given = "Jacob", family = "Socolar", role = "ctb"))
Description: A lightweight interface to 'Stan' <https://mc-stan.org>.
The 'CmdStanR' interface is an alternative to 'RStan' that calls the command
line interface for compilation and running algorithms instead of interfacing
Expand Down
12 changes: 10 additions & 2 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,14 @@ recompilation of Stan models. (#580)
* New methods for `posterior::as_draws()` for CmdStanR fitted model objects.
These are just wrappers around the `$draws()` method provided for convenience. (#532)

* Added E-BFMI checks that run automatically post sampling. (#500, @jsocolar)

* New method `$diagnostic_summary()` that summarizes the sampler diagnostics
(divergences, treedepth, ebfmi) and can regenerate the related warning messages. (#205)

* New `diagnostics` argument for the `$sample()` method to specify which
diagnostics are checked after sampling. Replaces the `validate_csv` argument, which is now deprecated. (#205)

* New function `as_mcmc.list()` for converting CmdStanMCMC objects to mcmc.list
objects from the coda package. (#584, @MatsuuraKentaro)

Expand Down Expand Up @@ -91,7 +99,7 @@ Stan programs requires CmdStan >= 2.26. (#434)

* New vignette on profiling Stan programs. (#435)

* New vignette on running Stan on the GPU with OpenCL. OpenCL device ids can
* New vignette on running Stan on the GPU with OpenCL. OpenCL device ids can
now also be specified at runtime. (#439)

* New check for invalid parameter names when supplying init values. (#452, @mike-lawrence)
Expand All @@ -101,7 +109,7 @@ now also be specified at runtime. (#439)
* New `error_on_NA` argument for `cmdstan_version()` to optionally return `NULL`
(instead of erroring) if the CmdStan path is not found (#467, @wlandau).

* Global option `cmdstanr_max_rows` can be set as an alternative to specifying
* Global option `cmdstanr_max_rows` can be set as an alternative to specifying
`max_rows` argument to the `$print()` method. (#470)

* New `output_basename` argument for the model fitting methods. Can be used in
Expand Down
15 changes: 12 additions & 3 deletions R/args.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ CmdStanArgs <- R6::R6Class(
refresh = NULL,
output_dir = NULL,
output_basename = NULL,
validate_csv = TRUE,
sig_figs = NULL,
opencl_ids = NULL,
model_variables = NULL) {
Expand All @@ -52,7 +51,6 @@ CmdStanArgs <- R6::R6Class(
self$method_args <- method_args
self$method <- self$method_args$method
self$save_latent_dynamics <- save_latent_dynamics
self$validate_csv <- validate_csv
self$using_tempdir <- is.null(output_dir)
if (getRversion() < "3.5.0") {
self$output_dir <- output_dir %||% tempdir()
Expand Down Expand Up @@ -196,7 +194,8 @@ SampleArgs <- R6::R6Class(
init_buffer = NULL,
term_buffer = NULL,
window = NULL,
fixed_param = FALSE) {
fixed_param = FALSE,
diagnostics = NULL) {

self$iter_warmup <- iter_warmup
self$iter_sampling <- iter_sampling
Expand All @@ -209,6 +208,11 @@ SampleArgs <- R6::R6Class(
self$metric <- metric
self$inv_metric <- inv_metric
self$fixed_param <- fixed_param
self$diagnostics <- diagnostics
if (identical(self$diagnostics, "")) {
self$diagnostics <- NULL
}

if (!is.null(inv_metric)) {
if (!is.null(metric_file)) {
stop("Only one of inv_metric and metric_file can be specified.",
Expand Down Expand Up @@ -636,6 +640,11 @@ validate_sample_args <- function(self, num_procs) {
validate_metric(self$metric)
validate_metric_file(self$metric_file, num_procs)

checkmate::assert_character(self$diagnostics, null.ok = TRUE, any.missing = FALSE)
if (!is.null(self$diagnostics)) {
checkmate::assert_subset(self$diagnostics, empty.ok = FALSE, choices = available_hmc_diagnostics())
}

invisible(TRUE)
}

Expand Down
10 changes: 6 additions & 4 deletions R/csv.R
Original file line number Diff line number Diff line change
Expand Up @@ -445,10 +445,6 @@ CmdStanMCMC_CSV <- R6::R6Class(
inherit = CmdStanMCMC,
public = list(
initialize = function(csv_contents, files, check_diagnostics = TRUE) {
if (check_diagnostics) {
check_divergences(csv_contents$post_warmup_sampler_diagnostics)
check_sampler_transitions_treedepth(csv_contents$post_warmup_sampler_diagnostics, csv_contents$metadata)
}
private$output_files_ <- files
private$metadata_ <- csv_contents$metadata
private$time_ <- csv_contents$time
Expand All @@ -457,6 +453,10 @@ CmdStanMCMC_CSV <- R6::R6Class(
private$warmup_sampler_diagnostics_ <- csv_contents$warmup_sampler_diagnostics
private$warmup_draws_ <- csv_contents$warmup_draws
private$draws_ <- csv_contents$post_warmup_draws
if (check_diagnostics) {
invisible(self$diagnostic_summary())
}
invisible(self)
},
# override some methods so they work without a CmdStanRun object
output_files = function(...) {
Expand All @@ -482,6 +482,7 @@ CmdStanMLE_CSV <- R6::R6Class(
private$output_files_ <- files
private$draws_ <- csv_contents$point_estimates
private$metadata_ <- csv_contents$metadata
invisible(self)
},
output_files = function(...) {
private$output_files_
Expand All @@ -497,6 +498,7 @@ CmdStanVB_CSV <- R6::R6Class(
private$output_files_ <- files
private$draws_ <- csv_contents$draws
private$metadata_ <- csv_contents$metadata
invisible(self)
},
output_files = function(...) {
private$output_files_
Expand Down
99 changes: 90 additions & 9 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,7 @@ CmdStanFit$set("public", name = "code", value = code)
#' |**Method**|**Description**|
#' |:----------|:---------------|
#' [`$summary()`][fit-method-summary] | Run [`posterior::summarise_draws()`][posterior::draws_summary]. |
#' [`$diagnostic_summary()`][fit-method-diagnostic_summary] | Get summaries of sampler diagnostics and warning messages. |
#' [`$cmdstan_summary()`][fit-method-cmdstan_summary] | Run and print CmdStan's `bin/stansummary`. |
#' [`$cmdstan_diagnose()`][fit-method-cmdstan_summary] | Run and print CmdStan's `bin/diagnose`. |
#' [`$loo()`][fit-method-loo] | Run [loo::loo.array()] for approximate LOO-CV |
Expand Down Expand Up @@ -856,14 +857,15 @@ CmdStanMCMC <- R6::R6Class(
warning("No chains finished successfully. Unable to retrieve the fit.",
call. = FALSE)
} else {
if (self$runset$args$validate_csv) {
fixed_param <- runset$args$method_args$fixed_param
private$read_csv_(variables = "",
sampler_diagnostics = if (!fixed_param) c("treedepth__", "divergent__") else "")
if (!fixed_param) {
check_divergences(private$sampler_diagnostics_)
check_sampler_transitions_treedepth(private$sampler_diagnostics_, private$metadata_)
}
if (runset$args$method_args$fixed_param) {
private$read_csv_(variables = "", sampler_diagnostics = "")
} else {
diagnostics <- self$runset$args$method_args$diagnostics
private$read_csv_(
variables = "",
sampler_diagnostics = convert_hmc_diagnostic_names(diagnostics)
)
invisible(self$diagnostic_summary(diagnostics, quiet = FALSE))
}
}
},
Expand Down Expand Up @@ -1047,7 +1049,9 @@ CmdStanMCMC$set("public", name = "loo", value = loo)
#' @name fit-method-sampler_diagnostics
#' @aliases sampler_diagnostics
#' @description Extract the values of sampler diagnostics for each iteration and
#' chain of MCMC.
#' chain of MCMC. To instead get summaries of these diagnostics and associated
#' warning messages use the
#' [`$diagnostic_summary()`][fit-method-diagnostic_summary] method.
#'
#' @param inc_warmup (logical) Should warmup draws be included? Defaults to `FALSE`.
#' @param format (string) The draws format to return. See
Expand Down Expand Up @@ -1106,6 +1110,83 @@ sampler_diagnostics <- function(inc_warmup = FALSE, format = getOption("cmdstanr
}
CmdStanMCMC$set("public", name = "sampler_diagnostics", value = sampler_diagnostics)

#' Sampler diagnostic summaries and warnings
#'
#' @name fit-method-diagnostic_summary
#' @aliases diagnostic_summary
#' @description Warnings and summaries of sampler diagnostics. To instead get
#'   the underlying values of the sampler diagnostics for each iteration and
#'   chain use the [`$sampler_diagnostics()`][fit-method-sampler_diagnostics]
#'   method.
#'
#'   Currently parameter-specific diagnostics like R-hat and effective sample
#'   size are _not_ handled by this method. Those diagnostics are provided via
#'   the [`$summary()`][fit-method-summary] method (using
#'   [posterior::summarize_draws()]).
#'
#' @param diagnostics (character vector) One or more diagnostics to check. The
#'   currently supported diagnostics are `"divergences"`, `"treedepth"`, and
#'   `"ebfmi"`. The default is to check all of them. If `NULL` or `""` no
#'   diagnostics are checked and an empty list is returned.
#' @param quiet (logical) Should warning messages about the diagnostics be
#'   suppressed? The default is `FALSE`, in which case warning messages are
#'   printed in addition to returning the values of the diagnostics.
#'
#' @return A list with as many named elements as `diagnostics` selected. The
#'   possible elements and their values are:
#'   * `"num_divergent"`: A vector of the number of divergences per chain.
#'   * `"num_max_treedepth"`: A vector of the number of times `max_treedepth`
#'     was hit per chain.
#'   * `"ebfmi"`: A vector of E-BFMI values per chain.
#'
#'   If a check returns `NULL` the corresponding element is `NA`, so a selected
#'   diagnostic is never silently dropped from the list.
#'
#' @seealso [`CmdStanMCMC`] and the
#'   [`$sampler_diagnostics()`][fit-method-sampler_diagnostics] method
#'
#' @examples
#' \dontrun{
#' fit <- cmdstanr_example("schools")
#' fit$diagnostic_summary()
#' fit$diagnostic_summary(quiet = TRUE)
#' }
#'
diagnostic_summary <- function(diagnostics = c("divergences", "treedepth", "ebfmi"),
                               quiet = FALSE) {
  out <- list()
  if (is.null(diagnostics) || identical(diagnostics, "")) {
    return(out)
  }
  diagnostics <- match.arg(
    diagnostics,
    choices = available_hmc_diagnostics(),
    several.ok = TRUE
  )

  # Evaluate one diagnostic check, silencing its messages when `quiet` is TRUE.
  # `check` is a lazily evaluated argument, so the check itself only runs
  # inside suppressMessages() in the quiet case. A NULL result is replaced by
  # NA because assigning NULL with `[[<-` would remove the list element.
  run_check <- function(check) {
    result <- if (quiet) suppressMessages(check) else check
    result %||% NA
  }

  post_warmup_sampler_diagnostics <- self$sampler_diagnostics(inc_warmup = FALSE)
  if ("divergences" %in% diagnostics) {
    out[["num_divergent"]] <- run_check(
      check_divergences(post_warmup_sampler_diagnostics)
    )
  }
  if ("treedepth" %in% diagnostics) {
    out[["num_max_treedepth"]] <- run_check(
      check_max_treedepth(post_warmup_sampler_diagnostics, self$metadata())
    )
  }
  if ("ebfmi" %in% diagnostics) {
    out[["ebfmi"]] <- run_check(
      check_ebfmi(post_warmup_sampler_diagnostics)
    )
  }
  out
}
CmdStanMCMC$set("public", name = "diagnostic_summary", value = diagnostic_summary)


#' Extract inverse metric (mass matrix) after MCMC
#'
#' @name fit-method-inv_metric
Expand Down
45 changes: 37 additions & 8 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -934,9 +934,12 @@ CmdStanModel$set("public", name = "format", value = format)
#' [CmdStan User’s Guide](https://mc-stan.org/docs/cmdstan-guide/)
#' for more details.
#'
#' After model fitting any diagnostics specified via the `diagnostics`
#' argument will be checked and warnings will be printed if warranted.
#'
#' @template model-common-args
#' @template model-sample-args
#' @param cores,num_cores,num_chains,num_warmup,num_samples,save_extra_diagnostics,max_depth,stepsize
#' @param cores,num_cores,num_chains,num_warmup,num_samples,save_extra_diagnostics,max_depth,stepsize,validate_csv
#' Deprecated and will be removed in a future release.
#'
#' @return A [`CmdStanMCMC`] object.
Expand Down Expand Up @@ -972,14 +975,15 @@ sample <- function(data = NULL,
term_buffer = NULL,
window = NULL,
fixed_param = FALSE,
validate_csv = TRUE,
show_messages = TRUE,
diagnostics = c("divergences", "treedepth", "ebfmi"),
# deprecated
cores = NULL,
num_cores = NULL,
num_chains = NULL,
num_warmup = NULL,
num_samples = NULL,
validate_csv = NULL,
save_extra_diagnostics = NULL,
max_depth = NULL,
stepsize = NULL) {
Expand Down Expand Up @@ -1016,6 +1020,17 @@ sample <- function(data = NULL,
warning("'save_extra_diagnostics' is deprecated. Please use 'save_latent_dynamics' instead.")
save_latent_dynamics <- save_extra_diagnostics
}
if (!is.null(validate_csv)) {
warning("'validate_csv' is deprecated. Please use 'diagnostics' instead.")
if (is.logical(validate_csv)) {
if (validate_csv) {
diagnostics <- c("divergences", "treedepth", "ebfmi")
} else {
diagnostics <- NULL
}
}
}

if (cmdstan_version() >= "2.27.0" && !fixed_param) {
if (self$has_stan_file() && file.exists(self$stan_file())) {
if (!is.null(self$variables()) && length(self$variables()$parameters) == 0) {
Expand Down Expand Up @@ -1051,7 +1066,8 @@ sample <- function(data = NULL,
init_buffer = init_buffer,
term_buffer = term_buffer,
window = window,
fixed_param = fixed_param
fixed_param = fixed_param,
diagnostics = diagnostics
)
args <- CmdStanArgs$new(
method_args = sample_args,
Expand All @@ -1068,7 +1084,6 @@ sample <- function(data = NULL,
output_dir = output_dir,
output_basename = output_basename,
sig_figs = sig_figs,
validate_csv = validate_csv,
opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()),
model_variables = model_variables
)
Expand Down Expand Up @@ -1163,8 +1178,22 @@ sample_mpi <- function(data = NULL,
window = NULL,
fixed_param = FALSE,
sig_figs = NULL,
validate_csv = TRUE,
show_messages = TRUE) {
show_messages = TRUE,
diagnostics = c("divergences", "treedepth", "ebfmi"),
# deprecated
validate_csv = TRUE) {

# `validate_csv` is deprecated in favor of `diagnostics`. Only honor it when
# the caller supplied it explicitly: its default value is non-NULL, so testing
# is.null() alone would emit the deprecation warning and overwrite the
# caller's `diagnostics` on every call.
if (!missing(validate_csv) && !is.null(validate_csv)) {
  warning("'validate_csv' is deprecated. Please use 'diagnostics' instead.")
  if (is.logical(validate_csv)) {
    # TRUE maps to checking all diagnostics, FALSE to checking none.
    if (validate_csv) {
      diagnostics <- c("divergences", "treedepth", "ebfmi")
    } else {
      diagnostics <- NULL
    }
  }
}

if (fixed_param) {
chains <- 1
save_warmup <- FALSE
Expand Down Expand Up @@ -1193,7 +1222,8 @@ sample_mpi <- function(data = NULL,
init_buffer = init_buffer,
term_buffer = term_buffer,
window = window,
fixed_param = fixed_param
fixed_param = fixed_param,
diagnostics = diagnostics
)
args <- CmdStanArgs$new(
method_args = sample_args,
Expand All @@ -1209,7 +1239,6 @@ sample_mpi <- function(data = NULL,
refresh = refresh,
output_dir = output_dir,
output_basename = output_basename,
validate_csv = validate_csv,
sig_figs = sig_figs,
model_variables = model_variables
)
Expand Down
2 changes: 1 addition & 1 deletion R/run.R
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,7 @@ CmdStanMCMCProcs <- R6::R6Class(
"seconds.\n")
cat("Total execution time:",
base::format(round(self$total_time(), 1), nsmall = 1),
"seconds.\n")
"seconds.\n\n")
} else if (num_failed == num_chains) {
warning("All chains finished unexpectedly! Use the $output(chain_id) method for more information.\n", call. = FALSE)
warning("Use read_cmdstan_csv() to read the results of the failed chains.",
Expand Down
Loading

0 comments on commit ac5a448

Please sign in to comment.