From e70b844eefca4107c8a7fbd1170e292ae9691d8f Mon Sep 17 00:00:00 2001 From: Jacob Socolar Date: Fri, 14 May 2021 21:48:01 -0500 Subject: [PATCH 01/44] Add check_bfmi function --- R/utils.R | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/R/utils.R b/R/utils.R index b52ddc5b4..bf41e2d66 100644 --- a/R/utils.R +++ b/R/utils.R @@ -291,6 +291,20 @@ check_sampler_transitions_treedepth <- function(post_warmup_sampler_diagnostics, } } +check_bfmi <- function(post_warmup_sampler_diagnostics) { + if (!is.null(post_warmup_sampler_diagnostics)) { + energy <- posterior::extract_variable_matrix(post_warmup_sampler_diagnostics, "energy__") + ebfmi <- apply(energy, 2, function(x) { + (sum(diff(x)^2)/length(x))/var(x) + }) + if (any(ebfmi < .3)) { + message(sum(ebfmi < .3), " of ", length(ebfmi) , " chains had estimated Bayesian fraction + of missing information(E-BFMI) less than 0.3, which may indicate poor exploration of the + posterior. Try to reparameterize the model.") + } + } +} + matching_variables <- function(variable_filters, variables) { not_found <- c() selected_variables <- c() From fdd6d738edb86317e9cf56ba59620c4bfc95ea81 Mon Sep 17 00:00:00 2001 From: rok-cesnovar Date: Sat, 15 May 2021 19:04:38 +0200 Subject: [PATCH 02/44] run check_bfmi after sampling --- R/csv.R | 1 + R/fit.R | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/R/csv.R b/R/csv.R index e7127b10d..cbcd1a8fc 100644 --- a/R/csv.R +++ b/R/csv.R @@ -438,6 +438,7 @@ CmdStanMCMC_CSV <- R6::R6Class( if (check_diagnostics) { check_divergences(csv_contents$post_warmup_sampler_diagnostics) check_sampler_transitions_treedepth(csv_contents$post_warmup_sampler_diagnostics, csv_contents$metadata) + check_bfmi(csv_contents$post_warmup_sampler_diagnostics) } private$output_files_ <- files private$metadata_ <- csv_contents$metadata diff --git a/R/fit.R b/R/fit.R index c0e564778..7f6345d21 100644 --- a/R/fit.R +++ b/R/fit.R @@ -829,10 +829,11 @@ CmdStanMCMC <- R6::R6Class( if (self$runset$args$validate_csv) { fixed_param <- runset$args$method_args$fixed_param private$read_csv_(variables = "", - sampler_diagnostics = if (!fixed_param) c("treedepth__", "divergent__") else "") + sampler_diagnostics = if (!fixed_param) c("treedepth__", "divergent__", "energy__") else "") if (!fixed_param) { check_divergences(private$sampler_diagnostics_) check_sampler_transitions_treedepth(private$sampler_diagnostics_, private$metadata_) + check_bfmi(private$sampler_diagnostics_) } } } From 430d63130182967e0212e4be53163ee8f1b7c653 Mon Sep 17 00:00:00 2001 From: jgabry Date: Sun, 16 May 2021 19:11:51 -0600 Subject: [PATCH 03/44] merge in master and fix conflict in utils.R --- DESCRIPTION | 3 +- NEWS.md | 4 + R/args.R | 17 +-- R/csv.R | 71 ++++++++---- R/data.R | 48 +++++---- R/example.R | 50 +++++++-- R/fit.R | 22 ++-- R/install.R | 133 ++++++++++++++--------- R/knitr.R | 1 - R/model.R | 50 +++++++-- R/path.R | 29 ++--- R/run.R | 28 ++--- R/utils.R | 161 ++++++++-------------------- R/zzz.R | 10 +- man/cmdstan_default_path.Rd | 7 +- man/install_cmdstan.Rd | 4 + man/set_cmdstan_path.Rd | 4 +- man/write_stan_file.Rd | 35 ++++-- man/write_stan_tempfile.Rd | 9 +- tests/testthat/test-csv.R | 26 +++++ tests/testthat/test-example.R | 32 ++++++ tests/testthat/test-model-compile.R | 31 ++++-- tests/testthat/test-model-init.R | 2 +- tests/testthat/test-utils.R | 39 ------- 24 files changed, 471 insertions(+), 345 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 6bfce2b3f..0a4f0da43 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -31,7 +31,8 @@ Imports: jsonlite (>= 1.2.0), posterior (>= 0.1.5), processx (>= 3.5.0), - R6 (>= 2.4.0) + R6 (>= 2.4.0), + rlang (>= 0.4.7) Suggests: bayesplot, knitr, diff --git a/NEWS.md b/NEWS.md index 8bc223764..26881f439 100644 --- a/NEWS.md +++ b/NEWS.md @@ -3,6 +3,10 @@ * Expose CmdStan's `diagnose` method that compares Stan's gradient computations to gradients computed via finite differences. (#485) +* `write_stan_file()` now choose file names deterministically based on the code +so that models do not get unnecessarily recompiled when calling the function +multiple times with the same code. (#495, @martinmodrak) + # cmdstanr 0.4.0 ### Bug fixes diff --git a/R/args.R b/R/args.R index 5bd49f5b8..a38ac930e 100644 --- a/R/args.R +++ b/R/args.R @@ -48,7 +48,7 @@ CmdStanArgs <- R6::R6Class( self$save_latent_dynamics <- save_latent_dynamics self$validate_csv <- validate_csv self$using_tempdir <- is.null(output_dir) - if (getRversion() < '3.5.0') { + if (getRversion() < "3.5.0") { self$output_dir <- output_dir %||% tempdir() } else { self$output_dir <- output_dir %||% tempdir(check = TRUE) @@ -90,7 +90,7 @@ CmdStanArgs <- R6::R6Class( if (type == "output" && !is.null(self$output_basename)) { basename <- self$output_basename } - generate_file_names( # defined in utils.R + generate_file_names( basename = basename, ext = ".csv", ids = self$proc_ids, @@ -508,7 +508,7 @@ DiagnoseArgs <- R6::R6Class( #' @noRd #' @param self A `CmdStanArgs` object. #' @return `TRUE` invisibly unless an error is thrown. -validate_cmdstan_args = function(self) { +validate_cmdstan_args <- function(self) { validate_exe_file(self$exe_file) checkmate::assert_directory_exists(self$output_dir, access = "rw") @@ -755,7 +755,7 @@ validate_exe_file <- function(exe_file) { if (!length(exe_file) || !nzchar(exe_file) || !file.exists(exe_file)) { - stop('Model not compiled. Try running the compile() method first.', + stop("Model not compiled. Try running the compile() method first.", call. = FALSE) } invisible(TRUE) @@ -776,8 +776,13 @@ process_init_list <- function(init, num_procs) { if (any(sapply(init, function(x) length(x) == 0))) { stop("'init' contains empty lists.", call. = FALSE) } - if (any(grepl("\\[",names(unlist(init))))) { - stop("'init' contains entries with parameter names that include square-brackets, which is not permitted. To supply inits for a vector, matrix or array of parameters, create a single entry with the parameter's name in the init list and specify init values for the entire parameter container.", call. = FALSE) + if (any(grepl("\\[", names(unlist(init))))) { + stop( + "'init' contains entries with parameter names that include square-brackets, which is not permitted. ", + "To supply inits for a vector, matrix or array of parameters, ", + "create a single entry with the parameter's name in the 'init' list ", + "and specify initial values for the entire parameter container.", + call. = FALSE) } init_paths <- tempfile( diff --git a/R/csv.R b/R/csv.R index cbcd1a8fc..264886e0a 100644 --- a/R/csv.R +++ b/R/csv.R @@ -210,7 +210,7 @@ read_cmdstan_csv <- function(files, selected_sampler_diag <- rep(FALSE, length(metadata$sampler_diagnostics)) not_found <- NULL for (p in sampler_diagnostics) { - matches <- metadata$sampler_diagnostics == p | startsWith(metadata$sampler_diagnostics, paste0(p,".")) + matches <- metadata$sampler_diagnostics == p | startsWith(metadata$sampler_diagnostics, paste0(p, ".")) if (!any(matches)) { not_found <- c(not_found, p) } @@ -243,10 +243,10 @@ read_cmdstan_csv <- function(files, ) if (metadata$method == "sample" && metadata$save_warmup == 1 && num_warmup_draws > 0) { warmup_sampler_diagnostics[[warmup_sd_id]] <- - post_warmup_sampler_diagnostics[[post_warmup_sd_id]][1:num_warmup_draws,,drop = FALSE] + post_warmup_sampler_diagnostics[[post_warmup_sd_id]][1:num_warmup_draws, , drop = FALSE] if (num_post_warmup_draws > 0) { post_warmup_sampler_diagnostics[[post_warmup_sd_id]] <- - post_warmup_sampler_diagnostics[[post_warmup_sd_id]][(num_warmup_draws+1):(num_warmup_draws + num_post_warmup_draws),,drop = FALSE] + post_warmup_sampler_diagnostics[[post_warmup_sd_id]][(num_warmup_draws + 1):(num_warmup_draws + num_post_warmup_draws), , drop = FALSE] } else { post_warmup_sampler_diagnostics[[post_warmup_sd_id]] <- NULL } @@ -264,9 +264,9 @@ read_cmdstan_csv <- function(files, ) if (metadata$method == "sample" && metadata$save_warmup == 1 && num_warmup_draws > 0) { warmup_draws[[warmup_draws_list_id]] <- - draws[[draws_list_id]][1:num_warmup_draws,,drop = FALSE] + draws[[draws_list_id]][1:num_warmup_draws, , drop = FALSE] if (num_post_warmup_draws > 0) { - draws[[draws_list_id]] <- draws[[draws_list_id]][(num_warmup_draws+1):(num_warmup_draws + num_post_warmup_draws),,drop = FALSE] + draws[[draws_list_id]] <- draws[[draws_list_id]][(num_warmup_draws + 1):(num_warmup_draws + num_post_warmup_draws), , drop = FALSE] } else { draws[[draws_list_id]] <- NULL } @@ -342,7 +342,7 @@ read_cmdstan_csv <- function(files, format <- "draws_matrix" } as_draws_format <- as_draws_format_fun(format) - variational_draws <- do.call(as_draws_format, list(draws[[1]][-1, colnames(draws[[1]]) != "lp__", drop=FALSE])) + variational_draws <- do.call(as_draws_format, list(draws[[1]][-1, colnames(draws[[1]]) != "lp__", drop = FALSE])) if (!is.null(variational_draws)) { if ("log_p__" %in% posterior::variables(variational_draws)) { variational_draws <- posterior::rename_variables(variational_draws, lp__ = "log_p__") @@ -361,7 +361,7 @@ read_cmdstan_csv <- function(files, format <- "draws_matrix" } as_draws_format <- as_draws_format_fun(format) - point_estimates <- do.call(as_draws_format, list(draws[[1]][1,, drop=FALSE])) + point_estimates <- do.call(as_draws_format, list(draws[[1]][1, , drop = FALSE])) point_estimates <- posterior::subset_draws(point_estimates, variable = variables) if (!is.null(point_estimates)) { posterior::variables(point_estimates) <- repaired_variables @@ -534,20 +534,16 @@ for (method in unavailable_methods_CmdStanFit_CSV) { #' read_csv_metadata <- function(csv_file) { checkmate::assert_file_exists(csv_file, access = "r", extension = "csv") - adaptation_terminated <- FALSE - param_names_read <- FALSE inv_metric_next <- FALSE - inv_metric_diagonal_next <- FALSE csv_file_info <- list() csv_file_info$inv_metric <- NULL inv_metric_rows_to_read <- -1 inv_metric_rows <- -1 - parsing_done <- FALSE dense_inv_metric <- FALSE diagnose_gradients <- FALSE gradients <- data.frame() warmup_time <- 0 - sampling_time <-0 + sampling_time <- 0 total_time <- 0 if (os_is_windows()) { grep_path <- repair_path(Sys.which("grep.exe")) @@ -562,7 +558,7 @@ read_csv_metadata <- function(csv_file) { stringsAsFactors = FALSE, fill = TRUE, sep = "", - header= FALSE + header = FALSE ) ) if (is.null(metadata) || length(metadata) == 0) { @@ -604,7 +600,7 @@ read_csv_metadata <- function(csv_file) { inv_metric_next <- FALSE } parse_key_val <- FALSE - } else if(diagnose_gradients){ + } else if (diagnose_gradients) { parse_key_val <- FALSE tmp <- gsub("#", "", line, fixed = TRUE) if (nzchar(tmp)) { @@ -665,8 +661,8 @@ read_csv_metadata <- function(csv_file) { } if (inv_metric_rows > 0 && csv_file_info$metric == "dense_e") { rows <- inv_metric_rows - cols <- length(csv_file_info$inv_metric)/inv_metric_rows - dim(csv_file_info$inv_metric) <- c(rows,cols) + cols <- length(csv_file_info$inv_metric) / inv_metric_rows + dim(csv_file_info$inv_metric) <- c(rows, cols) } # rename from old cmdstan names to new cmdstanX names @@ -723,7 +719,7 @@ check_csv_metadata_matches <- function(csv_metadata) { if (!all(method == method[1])) { stop("Supplied CSV files were produced by different methods and need to be read in separately!", call. = FALSE) } - for(i in 2:length(csv_metadata)) { + for (i in 2:length(csv_metadata)) { if (length(csv_metadata[[1]]$model_params) != length(csv_metadata[[i]]$model_params) || !all(csv_metadata[[1]]$model_params == csv_metadata[[i]]$model_params)) { stop("Supplied CSV files have samples for different variables!", call. = FALSE) @@ -736,7 +732,7 @@ check_csv_metadata_matches <- function(csv_metadata) { iter_warmup <- sapply(csv_metadata, function(x) x$iter_warmup) if (!all(iter_sampling == iter_sampling[1]) || !all(thin == thin[1]) || - !all(save_warmup == save_warmup[1])|| + !all(save_warmup == save_warmup[1]) || (save_warmup[1] == 1 && !all(iter_warmup == iter_warmup[1]))) { stop("Supplied CSV files do not match in the number of output samples!", call. = FALSE) } @@ -782,8 +778,8 @@ repair_variable_names <- function(names) { # convert names like beta[1,1] to beta.1.1 unrepair_variable_names <- function(names) { names <- sub("\\[", "\\.", names) - names <- gsub(",","\\.", names) - names <- gsub("\\]","", names) + names <- gsub(",", "\\.", names) + names <- gsub("\\]", "", names) names } @@ -806,7 +802,7 @@ remaining_columns_to_read <- function(requested, currently_read, all) { if (any(all_remaining == p)) { unread <- c(unread, p) } - is_unread_element <- startsWith(all_remaining, paste0(p,"[")) + is_unread_element <- startsWith(all_remaining, paste0(p, "[")) if (any(is_unread_element)) { unread <- c(unread, all_remaining[is_unread_element]) } @@ -818,3 +814,36 @@ remaining_columns_to_read <- function(requested, currently_read, all) { "" } } + +#' Returns a list of dimensions for the input variables. +#' +#' @noRd +#' @param variable_names A character vector of variable names including all +#' individual elements (e.g., `c("beta[1]", "beta[2]")`, not just `"beta"`). +#' @return A list giving the dimensions of the variables. The equivalent of the +#' `par_dims` slot of RStan's stanfit objects, except that scalars have +#' dimension `1` instead of `0`. +#' @note For this function to return the correct dimensions the input must be +#' already sorted in ascending order. Since CmdStan always has the variables +#' sorted correctly we avoid a sort by not sorting again here. +#' +variable_dims <- function(variable_names = NULL) { + if (is.null(variable_names)) { + return(NULL) + } + dims <- list() + uniq_variable_names <- unique(gsub("\\[.*\\]", "", variable_names)) + var_names <- gsub("\\]", "", variable_names) + for (var in uniq_variable_names) { + pattern <- paste0("^", var, "\\[") + var_indices <- var_names[grep(pattern, var_names)] + var_indices <- gsub(pattern, "", var_indices) + if (length(var_indices)) { + var_indices <- strsplit(var_indices[length(var_indices)], ",")[[1]] + dims[[var]] <- as.numeric(var_indices) + } else { + dims[[var]] <- 1 + } + } + dims +} diff --git a/R/data.R b/R/data.R index 3401248bc..0a57e0ed1 100644 --- a/R/data.R +++ b/R/data.R @@ -30,7 +30,7 @@ write_stan_json <- function(data, file) { } if (is.logical(var)) { - mode(var) <- "integer" # convert TRUE/FALSE to 1/0 + mode(var) <- "integer" } else if (is.data.frame(var)) { var <- data.matrix(var) } else if (is.list(var)) { @@ -40,7 +40,6 @@ write_stan_json <- function(data, file) { } # unboxing variables (N = 10 is stored as N : 10, not N: [10]) - # handling factors as integers jsonlite::write_json( data, path = file, @@ -66,7 +65,7 @@ list_to_array <- function(x, name = NULL) { } all_numeric <- all(sapply(x, function(a) is.numeric(a))) if (!all_numeric) { - stop("All elements in list '", name,"' must be numeric!", call. = FALSE) + stop("All elements in list '", name, "' must be numeric!", call. = FALSE) } element_num_of_dim <- length(all_dims[[1]]) x <- unlist(x) @@ -111,9 +110,7 @@ process_data <- function(data) { # check if any objects in the data list have zero as one of their dimensions any_zero_dims <- function(data) { - has_zero_dims <- sapply(data, function(x) { - any(dim(x) == 0) - }) + has_zero_dims <- sapply(data, function(x) any(dim(x) == 0)) any(has_zero_dims) } @@ -153,7 +150,11 @@ draws_to_csv <- function(draws, sampler_diagnostics = NULL) { } else { # create a dummy lp__ if it does not exist lp__ <- posterior::draws_array(lp__ = zeros, .nchains = n_chains) } - all_variables <- c("lp__", posterior::variables(sampler_diagnostics), draws_variables[!(draws_variables %in% c("lp__", "lp_approx__"))]) + all_variables <- c( + "lp__", + posterior::variables(sampler_diagnostics), + draws_variables[!(draws_variables %in% c("lp__", "lp_approx__"))] + ) draws <- posterior::subset_draws( posterior::bind_draws(draws, sampler_diagnostics, lp__, along = "variable"), variable = all_variables @@ -192,31 +193,34 @@ draws_to_csv <- function(draws, sampler_diagnostics = NULL) { process_fitted_params <- function(fitted_params) { if (is.character(fitted_params)) { paths <- absolute_path(fitted_params) - } else if (checkmate::test_r6(fitted_params, classes = "CmdStanMCMC") && - all(file.exists(fitted_params$output_files()))) { + } else if (checkmate::test_r6(fitted_params, "CmdStanMCMC") && + all(file.exists(fitted_params$output_files()))) { paths <- absolute_path(fitted_params$output_files()) - } else if(checkmate::test_r6(fitted_params, classes = c("CmdStanMCMC"))) { - draws <- tryCatch(fitted_params$draws(), - error=function(cond) { - stop("Unable to obtain draws from the fit object.", call. = FALSE) + } else if (checkmate::test_r6(fitted_params, "CmdStanMCMC")) { + draws <- tryCatch( + fitted_params$draws(), + error = function(cond) { + stop("Unable to obtain draws from the fit object.", call. = FALSE) } ) - sampler_diagnostics <- tryCatch(fitted_params$sampler_diagnostics(), - error=function(cond) { - NULL + sampler_diagnostics <- tryCatch( + fitted_params$sampler_diagnostics(), + error = function(cond) { + NULL } ) paths <- draws_to_csv(draws, sampler_diagnostics) - } else if(checkmate::test_r6(fitted_params, classes = c("CmdStanVB"))) { - draws <- tryCatch(fitted_params$draws(), - error=function(cond) { - stop("Unable to obtain draws from the fit object.", call. = FALSE) + } else if (checkmate::test_r6(fitted_params, "CmdStanVB")) { + draws <- tryCatch( + fitted_params$draws(), + error = function(cond) { + stop("Unable to obtain draws from the fit object.", call. = FALSE) } ) paths <- draws_to_csv(posterior::as_draws_array(draws)) - } else if (any(class(fitted_params) == "draws_array")){ + } else if (any(class(fitted_params) == "draws_array")) { paths <- draws_to_csv(fitted_params) - } else if (any(class(fitted_params) == "draws_matrix")){ + } else if (any(class(fitted_params) == "draws_matrix")) { paths <- draws_to_csv(posterior::as_draws_array(fitted_params)) } else { stop( diff --git a/R/example.R b/R/example.R index 773946189..237198935 100644 --- a/R/example.R +++ b/R/example.R @@ -91,18 +91,29 @@ print_example_program <- #' Write Stan code to a file #' #' Convenience function for writing Stan code to a (possibly -#' [temporary][base::tempfile]) file with a `.stan` extension. +#' [temporary][base::tempfile]) file with a `.stan` extension. By default, the +#' file name is chosen deterministically based on a [hash][rlang::hash()] +#' of the Stan code, and the file is not overwritten if it already has correct +#' contents. This means that calling this function multiple times with the same +#' Stan code will reuse the compiled model. This also however means that the +#' function is potentially not thread-safe. Using `hash_salt = Sys.getpid()` +#' should ensure thread-safety in the rare cases when it is needed. #' #' @export -#' @param code (multiple options) The Stan code: -#' * A single string containing a Stan program -#' * A character vector containing the individual lines of a Stan program. +#' @param code (character vector) The Stan code to write to the file. This can +#' be a character vector of length one (a string) containing the entire Stan +#' program or a character vector with each element containing one line of the +#' Stan program. #' @param dir (string) An optional path to the directory where the file will be #' written. If omitted, a [temporary directory][base::tempdir] is used by #' default. #' @param basename (string) If `dir` is specified, optionally the basename to -#' use for the file created. If not specified a file name is generated via -#' [base::tempfile()]. +#' use for the file created. If not specified a file name is generated +#' from [hashing][rlang::hash()] the code. +#' @param force_overwrite (logical) If set to `TRUE` the file will always be +#' overwritten and thus the resulting model will always be recompiled. +#' @param hash_salt (string) Text to add to the model code prior to hashing to +#' determine the file name if `basename` is not set. #' @return The path to the file. #' #' @examples @@ -131,19 +142,39 @@ print_example_program <- #' f2 <- write_stan_file(lines) #' identical(readLines(f), readLines(f2)) #' -write_stan_file <- function(code, dir = tempdir(), basename = NULL) { +write_stan_file <- function(code, dir = tempdir(), basename = NULL, + force_overwrite = FALSE, hash_salt = "") { if (!dir.exists(dir)) { dir.create(dir, recursive = TRUE) } + collapsed_code <- paste0(code, collapse = "\n") + if (!is.null(basename)) { if (!endsWith(basename, ".stan")) { basename <- paste0(basename, ".stan") } file <- file.path(dir, basename) } else { - file <- tempfile(fileext = ".stan", tmpdir = dir) + hash <- rlang::hash(paste0(hash_salt, collapsed_code)) + file <- file.path(dir, paste0("model_", hash, ".stan")) + } + overwrite <- TRUE + # Do not overwrite file if it has the correct contents (to avoid recompilation) + if (!force_overwrite && file.exists(file)) { + tryCatch({ + file_contents <- paste0(readLines(file), collapse = "\n") + if (gsub("\r|\n", "\n", file_contents) == gsub("\r|\n", "\n", collapsed_code)) { + overwrite <- FALSE + } + }, + error = function(e) { + warning("Error when checking old file contents", e) + }) + } + + if (overwrite) { + cat(code, file = file, sep = "\n") } - cat(code, file = file, sep = "\n") file } @@ -159,4 +190,3 @@ write_stan_tempfile <- function(code, dir = tempdir()) { call. = FALSE) write_stan_file(code, dir) } - diff --git a/R/fit.R b/R/fit.R index 7f6345d21..a9e2840ab 100644 --- a/R/fit.R +++ b/R/fit.R @@ -638,7 +638,7 @@ CmdStanFit$set("public", name = "time", value = time) output <- function(id = NULL) { # MCMC has separate implementation but doc is shared # Non-MCMC fit is obtained with one process only so id is ignored - cat(paste(self$runset$procs$proc_output(1), collapse="\n")) + cat(paste(self$runset$procs$proc_output(1), collapse = "\n")) } CmdStanFit$set("public", name = "output", value = output) @@ -843,7 +843,7 @@ CmdStanMCMC <- R6::R6Class( if (is.null(id)) { self$runset$procs$proc_output() } else { - cat(paste(self$runset$procs$proc_output(id), collapse="\n")) + cat(paste(self$runset$procs$proc_output(id), collapse = "\n")) } }, @@ -877,7 +877,7 @@ CmdStanMCMC <- R6::R6Class( variables <- matching_res$matching } if (inc_warmup) { - posterior::subset_draws(posterior::bind_draws(private$warmup_draws_, private$draws_, along="iteration"), variable = variables) + posterior::subset_draws(posterior::bind_draws(private$warmup_draws_, private$draws_, along = "iteration"), variable = variables) } else { posterior::subset_draws(private$draws_, variable = variables) } @@ -910,7 +910,7 @@ CmdStanMCMC <- R6::R6Class( private$draws_ <- posterior::bind_draws( private$draws_, posterior::subset_draws(csv_contents$post_warmup_draws, variable = missing_variables), - along="variable" + along = "variable" ) } } @@ -923,7 +923,7 @@ CmdStanMCMC <- R6::R6Class( private$sampler_diagnostics_ <- posterior::bind_draws( private$sampler_diagnostics_, posterior::subset_draws(csv_contents$post_warmup_sampler_diagnostics, variable = missing_variables), - along="variable" + along = "variable" ) } } @@ -937,7 +937,7 @@ CmdStanMCMC <- R6::R6Class( private$warmup_draws_ <- posterior::bind_draws( private$warmup_draws_, posterior::subset_draws(csv_contents$warmup_draws, variable = missing_variables), - along="variable" + along = "variable" ) } } @@ -949,7 +949,7 @@ CmdStanMCMC <- R6::R6Class( private$warmup_sampler_diagnostics_ <- posterior::bind_draws( private$warmup_sampler_diagnostics_, posterior::subset_draws(csv_contents$warmup_sampler_diagnostics, variable = missing_variables), - along="variable" + along = "variable" ) } } @@ -1069,7 +1069,7 @@ sampler_diagnostics <- function(inc_warmup = FALSE, format = getOption("cmdstanr posterior::bind_draws( private$warmup_sampler_diagnostics_, private$sampler_diagnostics_, - along="iteration" + along = "iteration" ) } else { private$sampler_diagnostics_ @@ -1134,7 +1134,7 @@ CmdStanMCMC$set("public", name = "inv_metric", value = inv_metric) #' fit_mcmc$num_chains() #' } #' -num_chains = function() { +num_chains <- function() { super$num_procs() } CmdStanMCMC$set("public", name = "num_chains", value = num_chains) @@ -1403,7 +1403,7 @@ CmdStanGQ <- R6::R6Class( if (is.null(id)) { self$runset$procs$proc_output() } else { - cat(paste(self$runset$procs$proc_output(id), collapse="\n")) + cat(paste(self$runset$procs$proc_output(id), collapse = "\n")) } } ), @@ -1426,7 +1426,7 @@ CmdStanGQ <- R6::R6Class( posterior::bind_draws( private$draws_, posterior::subset_draws(csv_contents$generated_quantities, variable = missing_variables), - along="variable" + along = "variable" ) } invisible(self) diff --git a/R/install.R b/R/install.R index 267c57100..20e3cd2d9 100644 --- a/R/install.R +++ b/R/install.R @@ -18,6 +18,11 @@ #' switches, changing the C++ compiler, etc. A change to the `make/local` file #' should typically be followed by calling `rebuild_cmdstan()`. #' +#' The `check_cmdstan_toolchain()` function attempts to check for the required +#' C++ toolchain. It is called internally by `install_cmdstan()` but can also +#' be called directly by the user. +#' +#' #' @export #' @param dir (string) The path to the directory in which to install CmdStan. #' The default is to install it in a directory called `.cmdstanr` within the @@ -105,8 +110,8 @@ install_cmdstan <- function(dir = NULL, message("* Installing CmdStan from ", release_url) download_url <- release_url split_url <- strsplit(release_url, "/") - tar_name <- utils::tail(split_url[[1]], n=1) - cmdstan_ver <- substr(tar_name, 0, nchar(tar_name)-7) + tar_name <- utils::tail(split_url[[1]], n = 1) + cmdstan_ver <- substr(tar_name, 0, nchar(tar_name) - 7) tar_gz_file <- paste0(cmdstan_ver, ".tar.gz") dir_cmdstan <- file.path(dir, cmdstan_ver) dest_file <- file.path(dir, tar_gz_file) @@ -165,7 +170,7 @@ install_cmdstan <- function(dir = NULL, cmdstan_make_local( dir = dir_cmdstan, cpp_options = list( - CXX="arch -arch arm64e clang++" + CXX = "arch -arch arm64e clang++" ), append = TRUE ) @@ -212,18 +217,18 @@ cmdstan_make_local <- function(dir = cmdstan_path(), append = TRUE) { make_local_path <- file.path(dir, "make", "local") if (!is.null(cpp_options)) { - built_flags = c() + built_flags <- c() for (i in seq_len(length(cpp_options))) { option_name <- names(cpp_options)[i] if (isTRUE(as.logical(cpp_options[[i]]))) { - built_flags = c(built_flags, paste0(option_name, "=true")) + built_flags <- c(built_flags, paste0(option_name, "=true")) } else if (isFALSE(as.logical(cpp_options[[i]]))) { - built_flags = c(built_flags, paste0(option_name, "=false")) + built_flags <- c(built_flags, paste0(option_name, "=false")) } else { if (is.null(option_name) || !nzchar(option_name)) { - built_flags = c(built_flags, paste0(cpp_options[[i]])) + built_flags <- c(built_flags, paste0(cpp_options[[i]])) } else { - built_flags = c(built_flags, paste0(option_name, "=", cpp_options[[i]])) + built_flags <- c(built_flags, paste0(option_name, "=", cpp_options[[i]])) } } } @@ -236,6 +241,34 @@ cmdstan_make_local <- function(dir = cmdstan_path(), } } +#' @rdname install_cmdstan +#' @export +#' @param fix For `check_cmdstan_toolchain()`, should CmdStanR attempt to fix +#' any detected toolchain problems? Currently this option is only available on +#' Windows. The default is `FALSE`, in which case problems are only reported +#' along with suggested fixes. +#' +check_cmdstan_toolchain <- function(fix = FALSE, quiet = FALSE) { + if (os_is_windows()) { + if (R.version$major >= "4") { + check_rtools40_windows_toolchain(fix = fix, quiet = quiet) + } else { + check_rtools35_windows_toolchain(fix = fix, quiet = quiet) + } + } else { + check_unix_make() + check_unix_cpp_compiler() + } + if (!checkmate::test_directory(dirname(tempdir()), access = "w")) { + stop("No write permissions to the temporary folder! Please change the permissions or location of the temporary folder.", call. = FALSE) + } + if (!quiet) { + message("The C++ toolchain required for CmdStan is setup properly!") + } + invisible(NULL) +} + + # internal ---------------------------------------------------------------- check_install_dir <- function(dir_cmdstan, overwrite = FALSE) { @@ -320,8 +353,8 @@ build_cmdstan <- function(dir, timeout) { translation_args <- NULL if (is_rosetta2()) { - run_cmd <- '/usr/bin/arch' - translation_args <- c('-arch', 'arm64e', 'make') + run_cmd <- "/usr/bin/arch" + translation_args <- c("-arch", "arm64e", "make") } else { run_cmd <- make_cmd() } @@ -333,7 +366,7 @@ build_cmdstan <- function(dir, echo = !quiet || is_verbose_mode(), spinner = quiet, error_on_status = FALSE, - stderr_line_callback = function(x,p) { if (quiet) message(x) }, + stderr_line_callback = function(x, p) { if (quiet) message(x) }, timeout = timeout ) } @@ -379,7 +412,7 @@ clean_cmdstan <- function(dir = cmdstan_path(), echo = !quiet || is_verbose_mode(), spinner = quiet, error_on_status = FALSE, - stderr_line_callback = function(x,p) { if (quiet) message(x) } + stderr_line_callback = function(x, p) { if (quiet) message(x) } ) clean_compile_helper_files() } @@ -393,7 +426,7 @@ build_example <- function(dir, cores, quiet, timeout) { echo = !quiet || is_verbose_mode(), spinner = quiet, error_on_status = FALSE, - stderr_line_callback = function(x,p) { if (quiet) message(x) }, + stderr_line_callback = function(x, p) { if (quiet) message(x) }, timeout = timeout ) } @@ -443,7 +476,7 @@ install_mingw32_make <- function(quiet = FALSE) { if (!quiet) message("Installing mingw32-make and writing RTools path to ~/.Renviron ...") processx::run( "pacman", - args = c("-Syu", "mingw-w64-x86_64-make","--noconfirm"), + args = c("-Syu", "mingw-w64-x86_64-make", "--noconfirm"), wd = rtools_usr_bin, error_on_status = TRUE, echo_cmd = is_verbose_mode(), @@ -451,7 +484,7 @@ install_mingw32_make <- function(quiet = FALSE) { ) write('PATH="${RTOOLS40_HOME}\\usr\\bin;${RTOOLS40_HOME}\\mingw64\\bin;${PATH}"', file = "~/.Renviron", append = TRUE) Sys.setenv(PATH = paste0(Sys.getenv("RTOOLS40_HOME"), "\\usr\\bin;", Sys.getenv("RTOOLS40_HOME"), "\\mingw64\\bin;", Sys.getenv("PATH"))) - invisible(NULL) + invisible(NULL) } check_rtools40_windows_toolchain <- function(fix = FALSE, quiet = FALSE) { @@ -487,7 +520,7 @@ check_rtools40_windows_toolchain <- function(fix = FALSE, quiet = FALSE) { } else { install_mingw32_make(quiet = quiet) check_rtools40_windows_toolchain(fix = FALSE, quiet = quiet) - return(invisible(NULL)) + return(invisible(NULL)) } } # Check if the mingw32-make and g++ get picked up by default are the RTools-supplied ones @@ -509,7 +542,9 @@ check_rtools40_windows_toolchain <- function(fix = FALSE, quiet = FALSE) { check_rtools35_windows_toolchain <- function(fix = FALSE, quiet = FALSE, paths = NULL) { - if (is.null(paths)) paths <- c( file.path("C:/", "Rtools"), file.path("C:/", "Rtools35")) + if (is.null(paths)) { + paths <- c(file.path("C:/", "Rtools"), file.path("C:/", "Rtools35")) + } mingw32_make_path <- dirname(Sys.which("mingw32-make")) gpp_path <- dirname(Sys.which("g++")) # If mingw32-make and g++ are not found, we check typical RTools 3.5 folders. @@ -546,13 +581,13 @@ check_rtools35_windows_toolchain <- function(fix = FALSE, message("Writing RTools path to ~/.Renviron ...") } if (!nzchar(Sys.getenv("RTOOLS35_HOME"))) { - write(paste0('RTOOLS35_HOME=', rtools_path), file = "~/.Renviron", append = TRUE) - Sys.setenv(RTOOLS35_HOME = rtools_path) + write(paste0("RTOOLS35_HOME=", rtools_path), file = "~/.Renviron", append = TRUE) + Sys.setenv(RTOOLS35_HOME = rtools_path) } write('PATH="${RTOOLS35_HOME}\\bin;${RTOOLS35_HOME}\\mingw_64\\bin;${PATH}"', file = "~/.Renviron", append = TRUE) - Sys.setenv(PATH = paste0(Sys.getenv("RTOOLS35_HOME"), "\\mingw_64\\bin;", Sys.getenv("PATH"))) + Sys.setenv(PATH = paste0(Sys.getenv("RTOOLS35_HOME"), "\\mingw_64\\bin;", Sys.getenv("PATH"))) check_rtools35_windows_toolchain(fix = FALSE, quiet = quiet) - return(invisible(NULL)) + return(invisible(NULL)) } else { stop( "\nA toolchain was not found. Please install RTools 3.5 and run", @@ -570,9 +605,19 @@ check_unix_make <- function() { make_path <- dirname(Sys.which("make")) if (!nzchar(make_path)) { if (os_is_macos()) { - stop("The 'make' tool was not found. Please install the command line tools for Mac with 'xcode-select --install' or install Xcode from the app store. Then restart R and run check_cmdstan_toolchain().", call. = FALSE) + stop( + "The 'make' tool was not found. ", + "Please install the command line tools for Mac with 'xcode-select --install' ", + "or install Xcode from the app store. ", + "Then restart R and run check_cmdstan_toolchain().", + call. = FALSE + ) } else { - stop("The 'make' tool was not found. Please install 'make', restart R, and then run check_cmdstan_toolchain().", call. = FALSE) + stop( + "The 'make' tool was not found. ", + "Please install 'make', restart R, and then run check_cmdstan_toolchain().", + call. = FALSE + ) } } @@ -583,36 +628,20 @@ check_unix_cpp_compiler <- function() { clang_path <- dirname(Sys.which("clang++")) if (!nzchar(gpp_path) && !nzchar(clang_path)) { if (os_is_macos()) { - stop("A suitable C++ compiler was not found. Please install the command line tools for Mac with 'xcode-select --install' or install Xcode from the app store. Then restart R and run check_cmdstan_toolchain().", call. = FALSE) - } else { - stop("A C++ compiler was not found. Please install the 'clang++' or 'g++' compiler, restart R, and run check_cmdstan_toolchain().", call. = FALSE) - } - } -} - -#' @rdname install_cmdstan -#' @export -#' @param fix For `check_cmdstan_toolchain()`, should CmdStanR attempt to fix -#' any detected toolchain problems? Currently this option is only available on -#' Windows. The default is `FALSE`, in which case problems are only reported -#' along with suggested fixes. -#' -check_cmdstan_toolchain <- function(fix = FALSE, quiet = FALSE) { - if (os_is_windows()) { - if (R.version$major >= "4") { - check_rtools40_windows_toolchain(fix = fix, quiet = quiet) + stop( + "A suitable C++ compiler was not found. ", + "Please install the command line tools for Mac with 'xcode-select --install' ", + "or install Xcode from the app store. ", + "Then restart R and run check_cmdstan_toolchain().", + call. = FALSE + ) } else { - check_rtools35_windows_toolchain(fix = fix, quiet = quiet) + stop( + "A C++ compiler was not found. ", + "Please install the 'clang++' or 'g++' compiler, restart R, ", + "and run check_cmdstan_toolchain().", + call. = FALSE + ) } - } else { - check_unix_make() - check_unix_cpp_compiler() - } - if (!checkmate::test_directory(dirname(tempdir()), access = "w")) { - stop("No write permissions to the temporary folder! Please change the permissions or location of the temporary folder.", call. = FALSE) } - if (!quiet) { - message("The C++ toolchain required for CmdStan is setup properly!") - } - invisible(NULL) } diff --git a/R/knitr.R b/R/knitr.R index d653d9dff..93a00b338 100644 --- a/R/knitr.R +++ b/R/knitr.R @@ -90,4 +90,3 @@ eng_cmdstan <- function(options) { code <- paste(options$code, collapse = "\n") knitr::engine_output(options, code, '') } - diff --git a/R/model.R b/R/model.R index 15cf55e96..5a80efb70 100644 --- a/R/model.R +++ b/R/model.R @@ -199,14 +199,13 @@ CmdStanModel <- R6::R6Class( ), public = list( initialize = function(stan_file, compile, ...) { + args <- list(...) checkmate::assert_file_exists(stan_file, access = "r", extension = "stan") checkmate::assert_flag(compile) private$stan_file_ <- absolute_path(stan_file) private$model_name_ <- sub(" ", "_", strip_ext(basename(private$stan_file_))) - args <- list(...) - check_stanc_options(args$stanc_options) private$precompile_cpp_options_ <- args$cpp_options %||% list() - private$precompile_stanc_options_ <- args$stanc_options %||% list() + private$precompile_stanc_options_ <- assert_valid_stanc_options(args$stanc_options) %||% list() private$precompile_include_paths_ <- args$include_paths private$dir_ <- args$dir @@ -368,7 +367,7 @@ compile <- function(quiet = TRUE, if (length(stanc_options) == 0 && !is.null(private$precompile_stanc_options_)) { stanc_options <- private$precompile_stanc_options_ } - check_stanc_options(stanc_options) + stanc_options <- assert_valid_stanc_options(stanc_options) if (is.null(include_paths) && !is.null(private$precompile_include_paths_)) { include_paths <- private$precompile_include_paths_ } @@ -502,7 +501,7 @@ compile <- function(quiet = TRUE, echo = !quiet || is_verbose_mode(), echo_cmd = is_verbose_mode(), spinner = quiet && interactive(), - stderr_line_callback = function(x,p) { + stderr_line_callback = function(x, p) { if (!startsWith(x, paste0(make_cmd(), ": *** No rule to make target"))) { message(x) } @@ -622,7 +621,7 @@ check_syntax <- function(pedantic = FALSE, if (is.null(stanc_options[["name"]])) { stanc_options[["name"]] <- paste0(self$model_name(), "_model") } - stanc_built_options = c() + stanc_built_options <- c() for (i in seq_len(length(stanc_options))) { option_name <- names(stanc_options)[i] if (isTRUE(as.logical(stanc_options[[i]]))) { @@ -641,10 +640,10 @@ check_syntax <- function(pedantic = FALSE, echo = is_verbose_mode(), echo_cmd = is_verbose_mode(), spinner = quiet && interactive(), - stdout_line_callback = function(x,p) { + stdout_line_callback = function(x, p) { if (!quiet) cat(x) }, - stderr_line_callback = function(x,p) { + stderr_line_callback = function(x, p) { message(x) }, error_on_status = FALSE @@ -1353,3 +1352,38 @@ assert_valid_threads <- function(threads, cpp_options, multiple_chains = FALSE) invisible(threads) } +assert_valid_stanc_options <- function(stanc_options) { + i <- 1 + names <- names(stanc_options) + for (s in stanc_options) { + if (!is.null(names[i]) && nzchar(names[i])) { + name <- names[i] + } else { + name <- s + } + if (startsWith(name, "--")) { + stop("No leading hyphens allowed in stanc options (", name, "). ", + "Use options without leading hyphens, for example ", + "`stanc_options = list('allow-undefined')`", + call. = FALSE) + } + i <- i + 1 + } + invisible(stanc_options) +} + +cpp_options_to_compile_flags <- function(cpp_options) { + if (length(cpp_options) == 0) { + return(NULL) + } + cpp_built_options <- c() + for (i in seq_along(cpp_options)) { + option_name <- names(cpp_options)[i] + if (is.null(option_name) || !nzchar(option_name)) { + cpp_built_options <- c(cpp_built_options, toupper(cpp_options[[i]])) + } else { + cpp_built_options <- c(cpp_built_options, paste0(toupper(option_name), "=", cpp_options[[i]])) + } + } + cpp_built_options +} diff --git a/R/path.R b/R/path.R index 9def9da50..8d37c6142 100644 --- a/R/path.R +++ b/R/path.R @@ -22,8 +22,8 @@ #' * If the [environment variable][Sys.setenv()] `"CMDSTAN"` exists at load time #' then its value will be automatically set as the default path to CmdStan for #' the \R session. -#' * If no environment variable is found when loaded but any directory in the form -#' `".cmdstanr/cmdstan-[version]"`, for example `".cmdstanr/cmdstan-2.23.0"`, +#' * If no environment variable is found when loaded but any directory in the +#' form `".cmdstanr/cmdstan-[version]"` (e.g., `".cmdstanr/cmdstan-2.23.0"`), #' exists in the user's home directory (`Sys.getenv("HOME")`, *not* the current #' working directory) then the path to the cmdstan with the largest version #' number will be set as the path to CmdStan for the \R session. This is the @@ -108,22 +108,23 @@ cmdstan_default_install_path <- function() { #' cmdstan_default_path #' -#' Returns the path to the installation of cmdstan with the most recent release version. +#' Returns the path to the installation of CmdStan with the most recent release +#' version. #' -#' @keywords internal -#' @return Path to the cmdstan installation with the most recent release version, NULL if no -#' installation found. #' @export +#' @keywords internal +#' @return Path to the CmdStan installation with the most recent release +#' version, or `NULL` if no installation found. +#' cmdstan_default_path <- function() { - installs_path <- file.path(Sys.getenv("HOME"), ".cmdstanr") + installs_path <- cmdstan_default_install_path() if (dir.exists(installs_path)) { cmdstan_installs <- list.dirs(path = installs_path, recursive = FALSE, full.names = FALSE) - # if installed in folder cmdstan, with no version - # move to cmdstan-version folder + # if installed in cmdstan folder with no version move to cmdstan-version folder if ("cmdstan" %in% cmdstan_installs) { ver <- read_cmdstan_version(file.path(installs_path, "cmdstan")) old_path <- file.path(installs_path, "cmdstan") - new_path <- file.path(installs_path, paste0("cmdstan-",ver)) + new_path <- file.path(installs_path, paste0("cmdstan-", ver)) file.rename(old_path, new_path) cmdstan_installs <- list.dirs(path = installs_path, recursive = FALSE, full.names = FALSE) } @@ -131,14 +132,14 @@ cmdstan_default_path <- function() { latest_cmdstan <- sort(cmdstan_installs, decreasing = TRUE)[1] if (is_release_candidate(latest_cmdstan)) { non_rc_path <- strsplit(latest_cmdstan, "-rc")[[1]][1] - if (dir.exists(file.path(installs_path,non_rc_path))) { + if (dir.exists(file.path(installs_path, non_rc_path))) { latest_cmdstan <- non_rc_path } } - return(file.path(installs_path,latest_cmdstan)) + return(file.path(installs_path, latest_cmdstan)) } } - return(NULL) + NULL } # unset the path (only used in tests) @@ -148,7 +149,7 @@ unset_cmdstan_path <- function() { } -#' Find the version of cmdstan from makefile +#' Find the version of CmdStan from makefile #' @noRd #' @param path Path to installation. #' @return Version number as a string. diff --git a/R/run.R b/R/run.R index 60a52fe94..ca9df9072 100644 --- a/R/run.R +++ b/R/run.R @@ -223,7 +223,7 @@ CmdStanRun <- R6::R6Class( stop("No CmdStan runs finished successfully. ", "Unable to run bin/", tool, ".", call. = FALSE) } - target_exe = file.path("bin", cmdstan_ext(tool)) + target_exe <- file.path("bin", cmdstan_ext(tool)) check_target_exe(target_exe) run_log <- processx::run( command = target_exe, @@ -303,7 +303,7 @@ check_target_exe <- function(exe) { on.exit(procs$cleanup(), add = TRUE) if (!is.null(mpi_cmd)) { if (is.null(mpi_args)) { - mpi_args = list() + mpi_args <- list() } mpi_args[["exe"]] <- self$exe_file() } @@ -560,11 +560,11 @@ CmdStanProcs <- R6::R6Class( private$proc_ids_ <- seq_len(num_procs) zeros <- rep(0, num_procs) names(zeros) <- private$proc_ids_ - private$proc_state_ = zeros - private$proc_start_time_ = zeros - private$proc_total_time_ = zeros - private$show_stderr_messages_ = show_stderr_messages - private$show_stdout_messages_ = show_stdout_messages + private$proc_state_ <- zeros + private$proc_start_time_ <- zeros + private$proc_total_time_ <- zeros + private$show_stderr_messages_ <- show_stderr_messages + private$show_stdout_messages_ <- show_stdout_messages invisible(self) }, num_procs = function() { @@ -602,7 +602,7 @@ CmdStanProcs <- R6::R6Class( for (i in names(mpi_args)) { mpi_args_vector <- c(paste0("-", i), mpi_args[[i]], mpi_args_vector) } - args = c(mpi_args_vector, exe_name, args) + args <- c(mpi_args_vector, exe_name, args) command <- mpi_cmd } private$processes_[[id]] <- processx::process$new( @@ -868,11 +868,11 @@ CmdStanMCMCProcs <- R6::R6Class( next_state <- 5 state <- 5 } - if (grepl("Gradient evaluation took",line, fixed = TRUE) - || grepl("leapfrog steps per transition would take",line, fixed = TRUE) - || grepl("Adjust your expectations accordingly!",line, fixed = TRUE) - || grepl("stanc_version",line, fixed = TRUE) - || grepl("stancflags",line, fixed = TRUE)) { + if (grepl("Gradient evaluation took", line, fixed = TRUE) + || grepl("leapfrog steps per transition would take", line, fixed = TRUE) + || grepl("Adjust your expectations accordingly!", line, fixed = TRUE) + || grepl("stanc_version", line, fixed = TRUE) + || grepl("stancflags", line, fixed = TRUE)) { ignore_line <- TRUE } if ((state > 1.5 && state < 5 && !ignore_line) || is_verbose_mode()) { @@ -1049,4 +1049,4 @@ check_tbb_path <- function() { Sys.setenv(PATH = paste0(path_to_TBB, ";", Sys.getenv("PATH"))) } } -} \ No newline at end of file +} diff --git a/R/utils.R b/R/utils.R index bf41e2d66..32a161616 100644 --- a/R/utils.R +++ b/R/utils.R @@ -16,6 +16,24 @@ is_verbose_mode <- function() { if (is.null(x) || length(x) == 0) y else x } +# used in both fit.R and csv.R for variable filtering +matching_variables <- function(variable_filters, variables) { + not_found <- c() + selected_variables <- c() + for (v in variable_filters) { + selected <- variables == v | startsWith(variables, paste0(v, "[")) + selected_variables <- c(selected_variables, variables[selected]) + variables <- variables[!selected] + if (!any(selected)) { + not_found <- c(not_found, v) + } + } + list( + matching = selected_variables, + not_found = not_found + ) +} + # checks for OS and hardware ---------------------------------------------- @@ -183,7 +201,7 @@ generate_file_names <- } if (random) { rand_num_pid <- as.integer(stats::runif(1, min = 0, max = 1E7)) + Sys.getpid() - rand <- format(as.hexmode(rand_num_pid) , width = 6) + rand <- format(as.hexmode(rand_num_pid), width = 6) new_names <- paste0(new_names, "-", rand) } @@ -194,41 +212,8 @@ generate_file_names <- new_names } +# threading helpers (deprecated) ------------------------------------------ -cpp_options_to_compile_flags <- function(cpp_options) { - if (length(cpp_options) == 0) { - return(NULL) - } - cpp_built_options = c() - for (i in seq_len(length(cpp_options))) { - option_name <- names(cpp_options)[i] - if (is.null(option_name) || !nzchar(option_name)) { - cpp_built_options = c(cpp_built_options, toupper(cpp_options[[i]])) - } else { - cpp_built_options = c(cpp_built_options, paste0(toupper(option_name), "=", cpp_options[[i]])) - } - } - cpp_built_options -} - -check_stanc_options <- function(stanc_options) { - i <- 1 - names <- names(stanc_options) - for (s in stanc_options){ - if (!is.null(names[i]) && nzchar(names[i])) { - name <- names[i] - } else { - name <- s - } - if (startsWith(name, "--")) { - stop("No leading hyphens allowed in stanc options (", name, "). ", - "Use options without leading hyphens, like for example ", - "`stanc_options = list('allow-undefined')`", - call. = FALSE) - } - i <- i + 1 - } -} #' Set or get the number of threads used to execute Stan models #' @@ -253,13 +238,15 @@ set_num_threads <- function(num_threads) { call. = FALSE) } + +# convergence checks ------------------------------------------------------ check_divergences <- function(post_warmup_sampler_diagnostics) { if (!is.null(post_warmup_sampler_diagnostics)) { divergences <- posterior::extract_variable_matrix(post_warmup_sampler_diagnostics, "divergent__") num_of_draws <- length(divergences) num_of_divergences <- sum(divergences) if (!is.na(num_of_divergences) && num_of_divergences > 0) { - percentage_divergences <- (num_of_divergences)/num_of_draws*100 + percentage_divergences <- 100 * num_of_divergences / num_of_draws message( "\nWarning: ", num_of_divergences, " of ", num_of_draws, " (", (format(round(percentage_divergences, 0), nsmall = 1)), "%)", @@ -280,13 +267,15 @@ check_sampler_transitions_treedepth <- function(post_warmup_sampler_diagnostics, num_of_draws <- length(treedepth) max_treedepth_hit <- sum(treedepth >= metadata$max_treedepth) if (!is.na(max_treedepth_hit) && max_treedepth_hit > 0) { - percentage_max_treedepth <- (max_treedepth_hit)/num_of_draws*100 - message(max_treedepth_hit, " of ", num_of_draws, " (", (format(round(percentage_max_treedepth, 0), nsmall = 1)), "%)", - " transitions hit the maximum treedepth limit of ", metadata$max_treedepth, - " or 2^", metadata$max_treedepth, "-1 leapfrog steps.\n", - "Trajectories that are prematurely terminated due to this limit will result in slow exploration.\n", - "Increasing the max_treedepth limit can avoid this at the expense of more computation.\n", - "If increasing max_treedepth does not remove warnings, try to reparameterize the model.\n") + percentage_max_treedepth <- 100 * max_treedepth_hit / num_of_draws + message( + max_treedepth_hit, " of ", num_of_draws, " (", (format(round(percentage_max_treedepth, 0), nsmall = 1)), "%)", + " transitions hit the maximum treedepth limit of ", metadata$max_treedepth, + " or 2^", metadata$max_treedepth, "-1 leapfrog steps.\n", + "Trajectories that are prematurely terminated due to this limit will result in slow exploration.\n", + "Increasing the max_treedepth limit can avoid this at the expense of more computation.\n", + "If increasing max_treedepth does not remove warnings, try to reparameterize the model.\n" + ) } } } @@ -299,61 +288,13 @@ check_bfmi <- function(post_warmup_sampler_diagnostics) { }) if (any(ebfmi < .3)) { message(sum(ebfmi < .3), " of ", length(ebfmi) , " chains had estimated Bayesian fraction - of missing information(E-BFMI) less than 0.3, which may indicate poor exploration of the + of missing information(E-BFMI) less than 0.3, which may indicate poor exploration of the posterior. Try to reparameterize the model.") } } } -matching_variables <- function(variable_filters, variables) { - not_found <- c() - selected_variables <- c() - for(v in variable_filters) { - selected <- variables == v | startsWith(variables, paste0(v, "[")) - selected_variables <- c(selected_variables, variables[selected]) - variables <- variables[!selected] - if (!any(selected)) { - not_found <- c(not_found, v) - } - } - list( - matching = selected_variables, - not_found = not_found - ) -} - -#' Returns a list of dimensions for the input variables. -#' -#' @noRd -#' @param variable_names A character vector of variable names including all -#' individual elements (e.g., `c("beta[1]", "beta[2]")`, not just `"beta"`). -#' @return A list giving the dimensions of the variables. The equivalent of the -#' `par_dims` slot of RStan's stanfit objects, except that scalars have -#' dimension `1` instead of `0`. -#' @note For this function to return the correct dimensions the input must be -#' already sorted in ascending order. Since CmdStan always has the variables -#' sorted correctly we avoid a sort by not sorting again here. -#' -variable_dims <- function(variable_names = NULL) { - if (is.null(variable_names)) { - return(NULL) - } - dims <- list() - uniq_variable_names <- unique(gsub("\\[.*\\]", "", variable_names)) - var_names <- gsub("\\]", "", variable_names) - for (var in uniq_variable_names) { - pattern <- paste0("^", var, "\\[") - var_indices <- var_names[grep(pattern, var_names)] - var_indices <- gsub(pattern, "", var_indices) - if (length(var_indices)) { - var_indices <- strsplit(var_indices[length(var_indices)], ",")[[1]] - dims[[var]] <- as.numeric(var_indices) - } else { - dims[[var]] <- 1 - } - } - dims -} +# draws formatting -------------------------------------------------------- as_draws_format_fun <- function(draws_format) { if (draws_format %in% c("draws_array", "array")) { @@ -384,26 +325,18 @@ valid_draws_formats <- function() { "draws_list", "list", "draws_df", "df", "data.frame") } -maybe_convert_draws_format <- function(draws, draws_format) { - if (!is.null(draws)) { - if (draws_format %in% c("draws_array", "array")) { - if (!posterior::is_draws_array(draws)) { - draws <- posterior::as_draws_array(draws) - } - } else if (draws_format %in% c("draws_df", "df", "data.frame")) { - if (!posterior::is_draws_df(draws)) { - draws <- posterior::as_draws_df(draws) - } - } else if (draws_format %in% c("draws_matrix", "matrix")) { - if (!posterior::is_draws_matrix(draws)) { - draws <- posterior::as_draws_matrix(draws) - } - } else if (draws_format %in% c("draws_list", "list")) { - if (!posterior::is_draws_list(draws)) { - draws <- posterior::as_draws_list(draws) - } - } +maybe_convert_draws_format <- function(draws, format) { + if (is.null(draws)) { + return(draws) } - draws + format <- sub("^draws_", "", format) + switch( + format, + "array" = posterior::as_draws_array(draws), + "df" = posterior::as_draws_df(draws), + "data.frame" = posterior::as_draws_df(draws), + "list" = posterior::as_draws_list(draws), + "matrix" = posterior::as_draws_matrix(draws), + stop("Invalid draws format.", call. = FALSE) + ) } - diff --git a/R/zzz.R b/R/zzz.R index 7897a8a46..4c1dfa327 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -27,7 +27,6 @@ startup_messages <- function() { } } - cmdstanr_initialize <- function() { # First check for environment variable CMDSTAN, but if not found # then see if default @@ -37,8 +36,11 @@ cmdstanr_initialize <- function() { path <- absolute_path(path) suppressMessages(set_cmdstan_path(path)) } else { - warning("Can't find directory specified by environment variable", - " 'CMDSTAN'. Path not set.", call. = FALSE) + warning( + "Can't find directory specified by environment variable 'CMDSTAN'. ", + "Path not set.", + call. = FALSE + ) .cmdstanr$PATH <- NULL } @@ -49,7 +51,7 @@ cmdstanr_initialize <- function() { } } - if (getRversion() < '3.5.0') { + if (getRversion() < "3.5.0") { .cmdstanr$TEMP_DIR <- tempdir() } else { .cmdstanr$TEMP_DIR <- tempdir(check = TRUE) diff --git a/man/cmdstan_default_path.Rd b/man/cmdstan_default_path.Rd index 98d1c515b..cfa4e0a59 100644 --- a/man/cmdstan_default_path.Rd +++ b/man/cmdstan_default_path.Rd @@ -7,10 +7,11 @@ cmdstan_default_path() } \value{ -Path to the cmdstan installation with the most recent release version, NULL if no -installation found. +Path to the CmdStan installation with the most recent release +version, or \code{NULL} if no installation found. } \description{ -Returns the path to the installation of cmdstan with the most recent release version. +Returns the path to the installation of CmdStan with the most recent release +version. } \keyword{internal} diff --git a/man/install_cmdstan.Rd b/man/install_cmdstan.Rd index cb7e98d6a..4d05aff29 100644 --- a/man/install_cmdstan.Rd +++ b/man/install_cmdstan.Rd @@ -105,6 +105,10 @@ Writing to the \code{make/local} file can be used to permanently add makefile flags/variables to an installation. For example adding specific compiler switches, changing the C++ compiler, etc. A change to the \code{make/local} file should typically be followed by calling \code{rebuild_cmdstan()}. + +The \code{check_cmdstan_toolchain()} function attempts to check for the required +C++ toolchain. It is called internally by \code{install_cmdstan()} but can also +be called directly by the user. } \examples{ \dontrun{ diff --git a/man/set_cmdstan_path.Rd b/man/set_cmdstan_path.Rd index 4b9a0f20e..683d5fc1c 100644 --- a/man/set_cmdstan_path.Rd +++ b/man/set_cmdstan_path.Rd @@ -43,8 +43,8 @@ this to avoid having to manually set the path every session: \item If the \link[=Sys.setenv]{environment variable} \code{"CMDSTAN"} exists at load time then its value will be automatically set as the default path to CmdStan for the \R session. -\item If no environment variable is found when loaded but any directory in the form -\code{".cmdstanr/cmdstan-[version]"}, for example \code{".cmdstanr/cmdstan-2.23.0"}, +\item If no environment variable is found when loaded but any directory in the +form \code{".cmdstanr/cmdstan-[version]"} (e.g., \code{".cmdstanr/cmdstan-2.23.0"}), exists in the user's home directory (\code{Sys.getenv("HOME")}, \emph{not} the current working directory) then the path to the cmdstan with the largest version number will be set as the path to CmdStan for the \R session. This is the diff --git a/man/write_stan_file.Rd b/man/write_stan_file.Rd index 820d50e6c..6955ec762 100644 --- a/man/write_stan_file.Rd +++ b/man/write_stan_file.Rd @@ -4,29 +4,46 @@ \alias{write_stan_file} \title{Write Stan code to a file} \usage{ -write_stan_file(code, dir = tempdir(), basename = NULL) +write_stan_file( + code, + dir = tempdir(), + basename = NULL, + force_overwrite = FALSE, + hash_salt = "" +) } \arguments{ -\item{code}{(multiple options) The Stan code: -\itemize{ -\item A single string containing a Stan program -\item A character vector containing the individual lines of a Stan program. -}} +\item{code}{(character vector) The Stan code to write to the file. This can +be a character vector of length one (a string) containing the entire Stan +program or a character vector with each element containing one line of the +Stan program.} \item{dir}{(string) An optional path to the directory where the file will be written. If omitted, a \link[base:tempfile]{temporary directory} is used by default.} \item{basename}{(string) If \code{dir} is specified, optionally the basename to -use for the file created. If not specified a file name is generated via -\code{\link[base:tempfile]{base::tempfile()}}.} +use for the file created. If not specified a file name is generated +from \link[rlang:hash]{hashing} the code.} + +\item{force_overwrite}{(logical) If set to \code{TRUE} the file will always be +overwritten and thus the resulting model will always be recompiled.} + +\item{hash_salt}{(string) Text to add to the model code prior to hashing to +determine the file name if \code{basename} is not set.} } \value{ The path to the file. } \description{ Convenience function for writing Stan code to a (possibly -\link[base:tempfile]{temporary}) file with a \code{.stan} extension. +\link[base:tempfile]{temporary}) file with a \code{.stan} extension. By default, the +file name is chosen deterministically based on a \link[rlang:hash]{hash} +of the Stan code, and the file is not overwritten if it already has correct +contents. This means that calling this function multiple times with the same +Stan code will reuse the compiled model. This also however means that the +function is potentially not thread-safe. Using \code{hash_salt = Sys.getpid()} +should ensure thread-safety in the rare cases when it is needed. } \examples{ # stan program as a single string diff --git a/man/write_stan_tempfile.Rd b/man/write_stan_tempfile.Rd index ba22a5b61..4434f1df2 100644 --- a/man/write_stan_tempfile.Rd +++ b/man/write_stan_tempfile.Rd @@ -7,11 +7,10 @@ write_stan_tempfile(code, dir = tempdir()) } \arguments{ -\item{code}{(multiple options) The Stan code: -\itemize{ -\item A single string containing a Stan program -\item A character vector containing the individual lines of a Stan program. -}} +\item{code}{(character vector) The Stan code to write to the file. This can +be a character vector of length one (a string) containing the entire Stan +program or a character vector with each element containing one line of the +Stan program.} \item{dir}{(string) An optional path to the directory where the file will be written. If omitted, a \link[base:tempfile]{temporary directory} is used by diff --git a/tests/testthat/test-csv.R b/tests/testthat/test-csv.R index 40a4604ab..2d6c0d4da 100644 --- a/tests/testthat/test-csv.R +++ b/tests/testthat/test-csv.R @@ -775,3 +775,29 @@ test_that("read_cmdstan_csv works with diagnose results", { expect_equal(diagnose_results$gradients$finite_diff, c(8.83081, 4.07931, -25.7167, -4.11423)) expect_equal(diagnose_results$gradients$error, c(9.919e-09, 3.13568e-08, -5.31186e-09, 5.87693e-09)) }) + +test_that("variable_dims() works", { + expect_null(variable_dims(NULL)) + + vars <- c("a", "b[1]", "b[2]", "b[3]", "c[1,1]", "c[1,2]") + vars_dims <- list(a = 1, b = 3, c = c(1,2)) + expect_equal(variable_dims(vars), vars_dims) + + vars <- c("a", "b") + vars_dims <- list(a = 1, b = 1) + expect_equal(variable_dims(vars), vars_dims) + + vars <- c("c[1,1]", "c[1,2]", "c[1,3]", "c[2,1]", "c[2,2]", "c[2,3]", "b[1]", "b[2]", "b[3]", "b[4]") + vars_dims <- list(c = c(2,3), b = 4) + expect_equal(variable_dims(vars), vars_dims) + + # make sure not confused by one name being last substring of another name + vars <- c("a[1]", "a[2]", "aa[1]", "aa[2]", "aa[3]") + expect_equal(variable_dims(vars), list(a = 2, aa = 3)) + + # wrong dimensions for descending order + vars <- c("c[1,1]", "c[1,2]", "c[1,3]", "c[2,3]", "c[2,2]", "c[2,1]", "b[4]", "b[2]", "b[3]", "b[1]") + vars_dims <- list(c = c(2,1), b = 1) + expect_equal(variable_dims(vars), vars_dims) +}) + diff --git a/tests/testthat/test-example.R b/tests/testthat/test-example.R index c7ca1f32a..84aa59280 100644 --- a/tests/testthat/test-example.R +++ b/tests/testthat/test-example.R @@ -67,6 +67,38 @@ test_that("write_stan_file creates dir if necessary", { ) }) +test_that("write_stan_file by default creates the same file for the same Stan model", { + dir <- file.path(test_path(), "answers") + + f1 <- write_stan_file(stan_program, dir = dir) + mtime1 <- file.info(f1)$mtime + + f2 <- write_stan_file(paste0(stan_program, "\n\n"), dir = dir) + expect_true(f1 != f2) + + # Test that writing the some model will not touch the file + # Wait a tiny bit to make sure the modified time will be different if + # overwrite happened + Sys.sleep(0.001) + f3 <- write_stan_file(stan_program, dir = dir) + expect_equal(f1, f3) + + mtime3 <- file.info(f3)$mtime + expect_equal(mtime1, mtime3) + + f4 <- write_stan_file(stan_program, dir = dir, hash_salt = "aaa") + expect_true(f1 != f4) + + f5 <- write_stan_file(stan_program, dir = dir, force_overwrite = TRUE) + expect_equal(f1, f5) + + mtime5 <- file.info(f5)$mtime + expect_true(mtime1 < mtime5) + + + try(file.remove(f1, f2, f4), silent = TRUE) +}) + test_that("write_stan_tempfile is deprecated", { expect_warning(write_stan_tempfile(stan_program), "deprecated") }) diff --git a/tests/testthat/test-model-compile.R b/tests/testthat/test-model-compile.R index 7a7c12a3c..663d00645 100644 --- a/tests/testthat/test-model-compile.R +++ b/tests/testthat/test-model-compile.R @@ -243,33 +243,33 @@ test_that("compiling stops on hyphens in stanc_options", { stan_file <- testing_stan_file("bernoulli") expect_error( cmdstan_model(stan_file, stanc_options = hyphens, compile = FALSE), - "No leading hyphens allowed in stanc options (--allow-undefined). Use options without leading hyphens, like for example `stanc_options = list('allow-undefined')`", + "No leading hyphens allowed in stanc options (--allow-undefined). Use options without leading hyphens, for example `stanc_options = list('allow-undefined')`", fixed = TRUE ) expect_error( cmdstan_model(stan_file, stanc_options = hyphens2, compile = FALSE), - "No leading hyphens allowed in stanc options (--allow-undefined). Use options without leading hyphens, like for example `stanc_options = list('allow-undefined')`", + "No leading hyphens allowed in stanc options (--allow-undefined). Use options without leading hyphens, for example `stanc_options = list('allow-undefined')`", fixed = TRUE ) expect_error( cmdstan_model(stan_file, stanc_options = hyphens3, compile = FALSE), - "No leading hyphens allowed in stanc options (--o). Use options without leading hyphens, like for example `stanc_options = list('allow-undefined')`", + "No leading hyphens allowed in stanc options (--o). Use options without leading hyphens, for example `stanc_options = list('allow-undefined')`", fixed = TRUE ) mod <- cmdstan_model(stan_file, compile = FALSE) expect_error( mod$compile(stanc_options = hyphens), - "No leading hyphens allowed in stanc options (--allow-undefined). Use options without leading hyphens, like for example `stanc_options = list('allow-undefined')`", + "No leading hyphens allowed in stanc options (--allow-undefined). Use options without leading hyphens, for example `stanc_options = list('allow-undefined')`", fixed = TRUE ) expect_error( mod$compile(stanc_options = hyphens2), - "No leading hyphens allowed in stanc options (--allow-undefined). Use options without leading hyphens, like for example `stanc_options = list('allow-undefined')`", + "No leading hyphens allowed in stanc options (--allow-undefined). Use options without leading hyphens, for example `stanc_options = list('allow-undefined')`", fixed = TRUE ) expect_error( mod$compile(stanc_options = hyphens3), - "No leading hyphens allowed in stanc options (--o). Use options without leading hyphens, like for example `stanc_options = list('allow-undefined')`", + "No leading hyphens allowed in stanc options (--o). Use options without leading hyphens, for example `stanc_options = list('allow-undefined')`", fixed = TRUE ) }) @@ -442,7 +442,22 @@ test_that("compiliation errors if folder with the model name exists", { } dir.create(exe) } - expect_error(cmdstan_model(stan_file), - "There is a subfolder matching the model name in the same folder as the model! Please remove or rename the subfolder and try again.") + expect_error( + cmdstan_model(stan_file), + "There is a subfolder matching the model name in the same folder as the model! Please remove or rename the subfolder and try again." + ) }) +test_that("cpp_options_to_compile_flags() works", { + options = list( + stan_threads = TRUE + ) + expect_equal(cpp_options_to_compile_flags(options), "STAN_THREADS=TRUE") + options = list( + stan_threads = TRUE, + stanc2 = TRUE + ) + expect_equal(cpp_options_to_compile_flags(options), c("STAN_THREADS=TRUE", "STANC2=TRUE")) + options = list() + expect_equal(cpp_options_to_compile_flags(options), NULL) +}) diff --git a/tests/testthat/test-model-init.R b/tests/testthat/test-model-init.R index 8ee1cfd26..4c0b89e41 100644 --- a/tests/testthat/test-model-init.R +++ b/tests/testthat/test-model-init.R @@ -166,7 +166,7 @@ test_that("error if init list is specified incorrectly", { init_list[[2]] = init_list[[1]] expect_error( mod_logistic$sample(data = data_list_logistic, chains = 2, init = init_list), - "'init' contains entries with parameter names that include square-brackets, which is not permitted. To supply inits for a vector, matrix or array of parameters, create a single entry with the parameter's name in the init list and specify init values for the entire parameter container." + "'init' contains entries with parameter names that include square-brackets, which is not permitted." ) }) diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R index 6ad351ebe..15d1b7d61 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-utils.R @@ -106,20 +106,6 @@ test_that("list_to_array fails for non-numeric values", { "All elements in list 'test-list' must be numeric!") }) -test_that("cpp_options_to_compile_flags() works", { - options = list( - stan_threads = TRUE - ) - expect_equal(cpp_options_to_compile_flags(options), "STAN_THREADS=TRUE") - options = list( - stan_threads = TRUE, - stanc2 = TRUE - ) - expect_equal(cpp_options_to_compile_flags(options), c("STAN_THREADS=TRUE", "STANC2=TRUE")) - options = list() - expect_equal(cpp_options_to_compile_flags(options), NULL) -}) - test_that("cmdstan_make_local() works", { exisiting_make_local <- cmdstan_make_local() make_local_path <- file.path(cmdstan_path(), "make", "local") @@ -153,31 +139,6 @@ test_that("cmdstan_make_local() works", { cmdstan_make_local(cpp_options = as.list(exisiting_make_local), append = FALSE) }) -test_that("variable_dims() works", { - expect_null(variable_dims(NULL)) - - vars <- c("a", "b[1]", "b[2]", "b[3]", "c[1,1]", "c[1,2]") - vars_dims <- list(a = 1, b = 3, c = c(1,2)) - expect_equal(variable_dims(vars), vars_dims) - - vars <- c("a", "b") - vars_dims <- list(a = 1, b = 1) - expect_equal(variable_dims(vars), vars_dims) - - vars <- c("c[1,1]", "c[1,2]", "c[1,3]", "c[2,1]", "c[2,2]", "c[2,3]", "b[1]", "b[2]", "b[3]", "b[4]") - vars_dims <- list(c = c(2,3), b = 4) - expect_equal(variable_dims(vars), vars_dims) - - # make sure not confused by one name being last substring of another name - vars <- c("a[1]", "a[2]", "aa[1]", "aa[2]", "aa[3]") - expect_equal(variable_dims(vars), list(a = 2, aa = 3)) - - # wrong dimensions for descending order - vars <- c("c[1,1]", "c[1,2]", "c[1,3]", "c[2,3]", "c[2,2]", "c[2,1]", "b[4]", "b[2]", "b[3]", "b[1]") - vars_dims <- list(c = c(2,1), b = 1) - expect_equal(variable_dims(vars), vars_dims) -}) - test_that("matching_variables() works", { ret <- matching_variables(c("beta"), c("alpha", "beta[1]", "beta[2]", "beta[3]")) expect_equal( From 635b66e049071f5c1e417c9a57ed96c005f3b691 Mon Sep 17 00:00:00 2001 From: "Jacob B. Socolar" Date: Thu, 20 May 2021 13:23:50 -0500 Subject: [PATCH 04/44] call var as stats::var Co-authored-by: Jonah Gabry --- R/utils.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/utils.R b/R/utils.R index 3cb2f64b5..f4746ad2c 100644 --- a/R/utils.R +++ b/R/utils.R @@ -284,7 +284,7 @@ check_bfmi <- function(post_warmup_sampler_diagnostics) { if (!is.null(post_warmup_sampler_diagnostics)) { energy <- posterior::extract_variable_matrix(post_warmup_sampler_diagnostics, "energy__") ebfmi <- apply(energy, 2, function(x) { - (sum(diff(x)^2)/length(x))/var(x) + (sum(diff(x)^2)/length(x))/stats::var(x) }) if (any(ebfmi < .3)) { message(sum(ebfmi < .3), " of ", length(ebfmi) , " chains had estimated Bayesian fraction From 9d0f895a76e733c0bdc6c37c6b1fcfcbf0fc1bc5 Mon Sep 17 00:00:00 2001 From: Jacob Socolar Date: Thu, 20 May 2021 14:00:19 -0500 Subject: [PATCH 05/44] better messages and error handling --- R/utils.R | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/R/utils.R b/R/utils.R index f4746ad2c..1aee3e38b 100644 --- a/R/utils.R +++ b/R/utils.R @@ -280,18 +280,26 @@ check_sampler_transitions_treedepth <- function(post_warmup_sampler_diagnostics, } } -check_bfmi <- function(post_warmup_sampler_diagnostics) { - if (!is.null(post_warmup_sampler_diagnostics)) { - energy <- posterior::extract_variable_matrix(post_warmup_sampler_diagnostics, "energy__") - ebfmi <- apply(energy, 2, function(x) { - (sum(diff(x)^2)/length(x))/stats::var(x) +check_ebfmi <- function(post_warmup_sampler_diagnostics, ebfmi_threshold = .3, return_ebfmi = F) { + pwsd <- posterior::as_draws_array(post_warmup_sampler_diagnostics) + if (! ("energy__" %in% dimnames(pwsd)$variable)) { + warning("e-bfmi not computed as the 'energy__' diagnostic could not be located") + } else if (dim(pwsd)[1] <= 1) { + warning("e-bfmi is undefined for posterior chains of length 1") + } else { + energy <- posterior::extract_variable_matrix(pwsd, "energy___") + fmi <- apply(energy, 2, function(x) { + (sum(diff(x)^2) / length(x)) / stats::var(x) }) - if (any(ebfmi < .3)) { - message(sum(ebfmi < .3), " of ", length(ebfmi) , " chains had estimated Bayesian fraction + if (any(fmi < ebfmi_threshold)) { + message(sum(fmi < ebfmi_threshold), " of ", length(fmi) , " chains had estimated Bayesian fraction of missing information(E-BFMI) less than 0.3, which may indicate poor exploration of the - posterior. Try to reparameterize the model.") + posterior. ") } } + if(return_ebfmi) { + fmi + } } From d5411b116b3a736a6e598d278fb0c131d49c0c56 Mon Sep 17 00:00:00 2001 From: Jacob Socolar Date: Thu, 20 May 2021 14:01:13 -0500 Subject: [PATCH 06/44] brought the return inside the if statement --- R/utils.R | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/R/utils.R b/R/utils.R index 1aee3e38b..11816454d 100644 --- a/R/utils.R +++ b/R/utils.R @@ -296,9 +296,9 @@ check_ebfmi <- function(post_warmup_sampler_diagnostics, ebfmi_threshold = .3, r of missing information(E-BFMI) less than 0.3, which may indicate poor exploration of the posterior. ") } - } - if(return_ebfmi) { - fmi + if(return_ebfmi) { + fmi + } } } From 7d7d6025f957b362843579ee6a4b46a81b1da628 Mon Sep 17 00:00:00 2001 From: Jacob Socolar Date: Thu, 20 May 2021 14:06:24 -0500 Subject: [PATCH 07/44] better message --- R/utils.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/utils.R b/R/utils.R index 11816454d..8645c28c0 100644 --- a/R/utils.R +++ b/R/utils.R @@ -293,7 +293,7 @@ check_ebfmi <- function(post_warmup_sampler_diagnostics, ebfmi_threshold = .3, r }) if (any(fmi < ebfmi_threshold)) { message(sum(fmi < ebfmi_threshold), " of ", length(fmi) , " chains had estimated Bayesian fraction - of missing information(E-BFMI) less than 0.3, which may indicate poor exploration of the + of missing information(E-BFMI) less than " threshold ", which may indicate poor exploration of the posterior. ") } if(return_ebfmi) { From 3050048de59638b355055a5cd7544c057ac907bc Mon Sep 17 00:00:00 2001 From: Jacob Socolar Date: Thu, 20 May 2021 14:55:06 -0500 Subject: [PATCH 08/44] fixed stupid error message typo --- R/utils.R | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/R/utils.R b/R/utils.R index 8645c28c0..a721fd291 100644 --- a/R/utils.R +++ b/R/utils.R @@ -282,19 +282,20 @@ check_sampler_transitions_treedepth <- function(post_warmup_sampler_diagnostics, check_ebfmi <- function(post_warmup_sampler_diagnostics, ebfmi_threshold = .3, return_ebfmi = F) { pwsd <- posterior::as_draws_array(post_warmup_sampler_diagnostics) - if (! ("energy__" %in% dimnames(pwsd)$variable)) { + if (!("energy__" %in% dimnames(pwsd)$variable)) { warning("e-bfmi not computed as the 'energy__' diagnostic could not be located") } else if (dim(pwsd)[1] <= 1) { warning("e-bfmi is undefined for posterior chains of length 1") } else { - energy <- posterior::extract_variable_matrix(pwsd, "energy___") + energy <- posterior::extract_variable_matrix(pwsd, "energy__") fmi <- apply(energy, 2, function(x) { (sum(diff(x)^2) / length(x)) / stats::var(x) - }) + } + ) if (any(fmi < ebfmi_threshold)) { - message(sum(fmi < ebfmi_threshold), " of ", length(fmi) , " chains had estimated Bayesian fraction - of missing information(E-BFMI) less than " threshold ", which may indicate poor exploration of the - posterior. ") + message(paste0(sum(fmi < ebfmi_threshold), " of ", length(fmi), " chains had estimated Bayesian fraction ", + "of missing information (E-BFMI) less than ", ebfmi_threshold, ", which may indicate poor exploration of the", + "posterior.")) } if(return_ebfmi) { fmi From fe16508dc572db55ac6520920a1e7f99756307eb Mon Sep 17 00:00:00 2001 From: Jacob Socolar Date: Thu, 20 May 2021 15:05:22 -0500 Subject: [PATCH 09/44] changed fmi to ebfmi everywhere for consistency --- R/utils.R | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/R/utils.R b/R/utils.R index a721fd291..fe571f3ee 100644 --- a/R/utils.R +++ b/R/utils.R @@ -288,17 +288,17 @@ check_ebfmi <- function(post_warmup_sampler_diagnostics, ebfmi_threshold = .3, r warning("e-bfmi is undefined for posterior chains of length 1") } else { energy <- posterior::extract_variable_matrix(pwsd, "energy__") - fmi <- apply(energy, 2, function(x) { + ebfmi <- apply(energy, 2, function(x) { (sum(diff(x)^2) / length(x)) / stats::var(x) } ) - if (any(fmi < ebfmi_threshold)) { - message(paste0(sum(fmi < ebfmi_threshold), " of ", length(fmi), " chains had estimated Bayesian fraction ", + if (any(ebfmi < ebfmi_threshold)) { + message(paste0(sum(ebfmi < ebfmi_threshold), " of ", length(ebfmi), " chains had estimated Bayesian fraction ", "of missing information (E-BFMI) less than ", ebfmi_threshold, ", which may indicate poor exploration of the", "posterior.")) } if(return_ebfmi) { - fmi + ebfmi } } } From 57264025c8d4f95567499f17db251f5f9d62cbd1 Mon Sep 17 00:00:00 2001 From: Jacob Socolar Date: Thu, 20 May 2021 15:06:03 -0500 Subject: [PATCH 10/44] formatting for consistency --- R/utils.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/utils.R b/R/utils.R index fe571f3ee..350e13db7 100644 --- a/R/utils.R +++ b/R/utils.R @@ -297,7 +297,7 @@ check_ebfmi <- function(post_warmup_sampler_diagnostics, ebfmi_threshold = .3, r "of missing information (E-BFMI) less than ", ebfmi_threshold, ", which may indicate poor exploration of the", "posterior.")) } - if(return_ebfmi) { + if (return_ebfmi) { ebfmi } } From 039fc28cc842930ecb38f9e8640ad2be414b243d Mon Sep 17 00:00:00 2001 From: Jacob Socolar Date: Thu, 20 May 2021 16:11:18 -0500 Subject: [PATCH 11/44] added tests --- R/utils.R | 10 +++++----- tests/testthat/test-utils.R | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/R/utils.R b/R/utils.R index 350e13db7..1648d86ce 100644 --- a/R/utils.R +++ b/R/utils.R @@ -283,9 +283,9 @@ check_sampler_transitions_treedepth <- function(post_warmup_sampler_diagnostics, check_ebfmi <- function(post_warmup_sampler_diagnostics, ebfmi_threshold = .3, return_ebfmi = F) { pwsd <- posterior::as_draws_array(post_warmup_sampler_diagnostics) if (!("energy__" %in% dimnames(pwsd)$variable)) { - warning("e-bfmi not computed as the 'energy__' diagnostic could not be located") + warning("E-BFMI not computed as the 'energy__' diagnostic could not be located") } else if (dim(pwsd)[1] <= 1) { - warning("e-bfmi is undefined for posterior chains of length 1") + warning("E-BFMI is undefined for posterior chains of length 1") } else { energy <- posterior::extract_variable_matrix(pwsd, "energy__") ebfmi <- apply(energy, 2, function(x) { @@ -293,9 +293,9 @@ check_ebfmi <- function(post_warmup_sampler_diagnostics, ebfmi_threshold = .3, r } ) if (any(ebfmi < ebfmi_threshold)) { - message(paste0(sum(ebfmi < ebfmi_threshold), " of ", length(ebfmi), " chains had estimated Bayesian fraction ", - "of missing information (E-BFMI) less than ", ebfmi_threshold, ", which may indicate poor exploration of the", - "posterior.")) + message(paste0(sum(ebfmi < ebfmi_threshold), " of ", length(ebfmi), " chains had energy-based Bayesian fraction ", + "of missing information (E-BFMI) less than ", ebfmi_threshold, ", which may indicate poor exploration of the ", + "posterior")) } if (return_ebfmi) { ebfmi diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R index 15d1b7d61..83428980f 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-utils.R @@ -171,3 +171,17 @@ test_that("matching_variables() works", { ) expect_equal(length(ret$not_found), 0) }) + +test_that("check_ebfmi works", { + set.seed(1) + energy_df <- data.frame("energy__" = rnorm(1000)) + expect_error(check_ebfmi(energy_df), NA) + energy_df[1] <- 0 + for(i in 1:999){ + energy_df$energy__[i+1] <- energy_df$energy__[i] + rnorm(1, 0, 0.01) + } + expect_message(check_ebfmi(energy_df), "fraction of missing information \\(E-BFMI\\) less than") + energy_vec <- energy_df$energy__ + expect_equal(suppressMessages(check_ebfmi(energy_df, return_ebfmi = TRUE)), + (sum(diff(energy_vec)^2) / length(energy_vec)) / stats::var(energy_vec)) +}) From 9634ad797b2afad73106ba3f273be40eadce4009 Mon Sep 17 00:00:00 2001 From: Jacob Socolar Date: Thu, 20 May 2021 16:49:20 -0500 Subject: [PATCH 12/44] better error messages --- R/utils.R | 21 ++++++++++++++++++--- tests/testthat/test-utils.R | 3 +++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/R/utils.R b/R/utils.R index ecdb7bf91..6dd66e127 100644 --- a/R/utils.R +++ b/R/utils.R @@ -280,14 +280,29 @@ check_sampler_transitions_treedepth <- function(post_warmup_sampler_diagnostics, } } -check_ebfmi <- function(post_warmup_sampler_diagnostics, ebfmi_threshold = .3, return_ebfmi = F) { +check_ebfmi <- function(post_warmup_sampler_diagnostics, ebfmi_threshold = .2, return_ebfmi = F) { pwsd <- posterior::as_draws_array(post_warmup_sampler_diagnostics) if (!("energy__" %in% dimnames(pwsd)$variable)) { - warning("E-BFMI not computed as the 'energy__' diagnostic could not be located") + if (! return_ebfmi) { + warning("E-BFMI not computed as the 'energy__' diagnostic could not be located") + } else { + stop("E-BFMI not computed as the 'energy__' diagnostic could not be located") + } } else if (dim(pwsd)[1] <= 1) { - warning("E-BFMI is undefined for posterior chains of length 1") + if (! return_ebfmi) { + warning("E-BFMI is undefined for posterior chains of length less than 2") + } else { + stop("E-BFMI is undefined for posterior chains of length less than 2") + } } else { energy <- posterior::extract_variable_matrix(pwsd, "energy__") + if (any is.na(energy)) { + if (! return_ebfmi) { + warning("E-BFMI not computed 'energy__' contains NAs") + } else { + stop("E-BFMI not computed 'energy__' contains NAs") + } + } ebfmi <- apply(energy, 2, function(x) { (sum(diff(x)^2) / length(x)) / stats::var(x) } diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R index 8afc66000..07e04f67f 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-utils.R @@ -184,6 +184,9 @@ test_that("check_ebfmi works", { energy_vec <- energy_df$energy__ expect_equal(suppressMessages(check_ebfmi(energy_df, return_ebfmi = TRUE)), (sum(diff(energy_vec)^2) / length(energy_vec)) / stats::var(energy_vec)) + energy_df <- data.frame("energy__" = 0) + expect_error(check_ebfmi(energy_df, return_ebfmi = TRUE), "E-BFMI is undefined for posterior chains of length 1") + expect_warning(check_ebfmi(energy_df), "E-BFMI is undefined for posterior chains of length 1") }) test_that("require_suggested_package() works", { From d1ba154b40b1413386c8db18832f96a5ffbc0ed1 Mon Sep 17 00:00:00 2001 From: Jacob Socolar Date: Thu, 20 May 2021 16:58:35 -0500 Subject: [PATCH 13/44] cleaning up typos from prev commit --- R/utils.R | 6 +++--- tests/testthat/test-utils.R | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/R/utils.R b/R/utils.R index 6dd66e127..73958cbb4 100644 --- a/R/utils.R +++ b/R/utils.R @@ -296,11 +296,11 @@ check_ebfmi <- function(post_warmup_sampler_diagnostics, ebfmi_threshold = .2, r } } else { energy <- posterior::extract_variable_matrix(pwsd, "energy__") - if (any is.na(energy)) { + if (any(is.na(energy))) { if (! return_ebfmi) { - warning("E-BFMI not computed 'energy__' contains NAs") + warning("E-BFMI not computed because 'energy__' contains NAs") } else { - stop("E-BFMI not computed 'energy__' contains NAs") + stop("E-BFMI not computed because 'energy__' contains NAs") } } ebfmi <- apply(energy, 2, function(x) { diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R index 07e04f67f..4239ad8a9 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-utils.R @@ -185,8 +185,8 @@ test_that("check_ebfmi works", { expect_equal(suppressMessages(check_ebfmi(energy_df, return_ebfmi = TRUE)), (sum(diff(energy_vec)^2) / length(energy_vec)) / stats::var(energy_vec)) energy_df <- data.frame("energy__" = 0) - expect_error(check_ebfmi(energy_df, return_ebfmi = TRUE), "E-BFMI is undefined for posterior chains of length 1") - expect_warning(check_ebfmi(energy_df), "E-BFMI is undefined for posterior chains of length 1") + expect_error(check_ebfmi(energy_df, return_ebfmi = TRUE), "E-BFMI is undefined for posterior chains of length less than 2") + expect_warning(check_ebfmi(energy_df), "E-BFMI is undefined for posterior chains of length less than 2") }) test_that("require_suggested_package() works", { From 1c5cb6204882cfffcd096a853f6c8f94e315aa20 Mon Sep 17 00:00:00 2001 From: Jacob Socolar Date: Thu, 20 May 2021 17:12:43 -0500 Subject: [PATCH 14/44] realized the ebfmi is meaningless for chains of length 2 --- R/utils.R | 6 +++--- tests/testthat/test-utils.R | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/R/utils.R b/R/utils.R index 73958cbb4..2703eef7c 100644 --- a/R/utils.R +++ b/R/utils.R @@ -288,11 +288,11 @@ check_ebfmi <- function(post_warmup_sampler_diagnostics, ebfmi_threshold = .2, r } else { stop("E-BFMI not computed as the 'energy__' diagnostic could not be located") } - } else if (dim(pwsd)[1] <= 1) { + } else if (dim(pwsd)[1] <= 2) { if (! return_ebfmi) { - warning("E-BFMI is undefined for posterior chains of length less than 2") + warning("E-BFMI is undefined for posterior chains of length less than 3") } else { - stop("E-BFMI is undefined for posterior chains of length less than 2") + stop("E-BFMI is undefined for posterior chains of length less than 3") } } else { energy <- posterior::extract_variable_matrix(pwsd, "energy__") diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R index 4239ad8a9..10870c423 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-utils.R @@ -185,8 +185,8 @@ test_that("check_ebfmi works", { expect_equal(suppressMessages(check_ebfmi(energy_df, return_ebfmi = TRUE)), (sum(diff(energy_vec)^2) / length(energy_vec)) / stats::var(energy_vec)) energy_df <- data.frame("energy__" = 0) - expect_error(check_ebfmi(energy_df, return_ebfmi = TRUE), "E-BFMI is undefined for posterior chains of length less than 2") - expect_warning(check_ebfmi(energy_df), "E-BFMI is undefined for posterior chains of length less than 2") + expect_error(check_ebfmi(energy_df, return_ebfmi = TRUE), "E-BFMI is undefined for posterior chains of length less than 3") + expect_warning(check_ebfmi(energy_df), "E-BFMI is undefined for posterior chains of length less than 3") }) test_that("require_suggested_package() works", { From b590c0a2bbd084fc66def7ea017c0073e76e67be Mon Sep 17 00:00:00 2001 From: Jacob Socolar Date: Fri, 21 May 2021 10:02:12 -0500 Subject: [PATCH 15/44] change check_bfmi to check_ebfmi --- R/fit.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/fit.R b/R/fit.R index ce5aa603f..e9d1718ac 100644 --- a/R/fit.R +++ b/R/fit.R @@ -833,7 +833,7 @@ CmdStanMCMC <- R6::R6Class( if (!fixed_param) { check_divergences(private$sampler_diagnostics_) check_sampler_transitions_treedepth(private$sampler_diagnostics_, private$metadata_) - check_bfmi(private$sampler_diagnostics_) + check_ebfmi(private$sampler_diagnostics_) } } } From 95e9e1f90248c7e165c7eb67b4fdf5fdb3ada386 Mon Sep 17 00:00:00 2001 From: Jacob Socolar Date: Fri, 21 May 2021 10:05:09 -0500 Subject: [PATCH 16/44] change check_bfmi to check_ebfmi --- R/csv.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/csv.R b/R/csv.R index 264886e0a..419970a80 100644 --- a/R/csv.R +++ b/R/csv.R @@ -438,7 +438,7 @@ CmdStanMCMC_CSV <- R6::R6Class( if (check_diagnostics) { check_divergences(csv_contents$post_warmup_sampler_diagnostics) check_sampler_transitions_treedepth(csv_contents$post_warmup_sampler_diagnostics, csv_contents$metadata) - check_bfmi(csv_contents$post_warmup_sampler_diagnostics) + check_ebfmi(csv_contents$post_warmup_sampler_diagnostics) } private$output_files_ <- files private$metadata_ <- csv_contents$metadata From 5e8e2503d91a7036526429c7834f2bb12e0ffeab Mon Sep 17 00:00:00 2001 From: rok-cesnovar Date: Tue, 25 May 2021 12:25:38 +0200 Subject: [PATCH 17/44] fix testing issues --- R/utils.R | 64 ++++++++++++++++++---------------- tests/testthat/test-fit-mcmc.R | 2 +- 2 files changed, 34 insertions(+), 32 deletions(-) diff --git a/R/utils.R b/R/utils.R index 2703eef7c..26282a567 100644 --- a/R/utils.R +++ b/R/utils.R @@ -281,41 +281,43 @@ check_sampler_transitions_treedepth <- function(post_warmup_sampler_diagnostics, } check_ebfmi <- function(post_warmup_sampler_diagnostics, ebfmi_threshold = .2, return_ebfmi = F) { - pwsd <- posterior::as_draws_array(post_warmup_sampler_diagnostics) - if (!("energy__" %in% dimnames(pwsd)$variable)) { - if (! return_ebfmi) { - warning("E-BFMI not computed as the 'energy__' diagnostic could not be located") - } else { - stop("E-BFMI not computed as the 'energy__' diagnostic could not be located") - } - } else if (dim(pwsd)[1] <= 2) { - if (! return_ebfmi) { - warning("E-BFMI is undefined for posterior chains of length less than 3") - } else { - stop("E-BFMI is undefined for posterior chains of length less than 3") - } - } else { - energy <- posterior::extract_variable_matrix(pwsd, "energy__") - if (any(is.na(energy))) { + if (!is.null(post_warmup_sampler_diagnostics) && posterior::niterations(post_warmup_sampler_diagnostics) > 0) { + pwsd <- posterior::as_draws_array(post_warmup_sampler_diagnostics) + if (!("energy__" %in% dimnames(pwsd)$variable)) { if (! return_ebfmi) { - warning("E-BFMI not computed because 'energy__' contains NAs") + warning("E-BFMI not computed as the 'energy__' diagnostic could not be located") } else { - stop("E-BFMI not computed because 'energy__' contains NAs") + stop("E-BFMI not computed as the 'energy__' diagnostic could not be located") + } + } else if (dim(pwsd)[1] <= 2) { + if (! return_ebfmi) { + warning("E-BFMI is undefined for posterior chains of length less than 3") + } else { + stop("E-BFMI is undefined for posterior chains of length less than 3") + } + } else { + energy <- posterior::extract_variable_matrix(pwsd, "energy__") + if (any(is.na(energy))) { + if (! return_ebfmi) { + warning("E-BFMI not computed because 'energy__' contains NAs") + } else { + stop("E-BFMI not computed because 'energy__' contains NAs") + } + } + ebfmi <- apply(energy, 2, function(x) { + (sum(diff(x)^2) / length(x)) / stats::var(x) + } + ) + if (any(ebfmi < ebfmi_threshold)) { + message(paste0(sum(ebfmi < ebfmi_threshold), " of ", length(ebfmi), " chains had energy-based Bayesian fraction ", + "of missing information (E-BFMI) less than ", ebfmi_threshold, ", which may indicate poor exploration of the ", + "posterior")) + } + if (return_ebfmi) { + ebfmi } } - ebfmi <- apply(energy, 2, function(x) { - (sum(diff(x)^2) / length(x)) / stats::var(x) - } - ) - if (any(ebfmi < ebfmi_threshold)) { - message(paste0(sum(ebfmi < ebfmi_threshold), " of ", length(ebfmi), " chains had energy-based Bayesian fraction ", - "of missing information (E-BFMI) less than ", ebfmi_threshold, ", which may indicate poor exploration of the ", - "posterior")) - } - if (return_ebfmi) { - ebfmi - } - } + } } diff --git a/tests/testthat/test-fit-mcmc.R b/tests/testthat/test-fit-mcmc.R index edb3bd4ad..1696bad33 100644 --- a/tests/testthat/test-fit-mcmc.R +++ b/tests/testthat/test-fit-mcmc.R @@ -50,7 +50,7 @@ test_that("draws() works when gradually adding variables", { expect_equal(posterior::variables(draws_lp__), c("lp__")) expect_type(sampler_diagnostics, "double") expect_s3_class(sampler_diagnostics, "draws_array") - expect_equal(posterior::variables(sampler_diagnostics), c(c("treedepth__", "divergent__", "accept_stat__", "stepsize__", "n_leapfrog__", "energy__"))) + expect_equal(posterior::variables(sampler_diagnostics), c(c("treedepth__", "divergent__", "energy__", "accept_stat__", "stepsize__", "n_leapfrog__"))) draws_alpha <- fit$draws(variables = c("alpha"), inc_warmup = TRUE) expect_type(draws_alpha, "double") expect_s3_class(draws_alpha, "draws_array") From e130563051f8d112135be5c9de3abdbd6ffe27c8 Mon Sep 17 00:00:00 2001 From: rok-cesnovar Date: Tue, 25 May 2021 13:15:43 +0200 Subject: [PATCH 18/44] remove niterations call --- R/utils.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/utils.R b/R/utils.R index 26282a567..34d9eff6c 100644 --- a/R/utils.R +++ b/R/utils.R @@ -281,7 +281,7 @@ check_sampler_transitions_treedepth <- function(post_warmup_sampler_diagnostics, } check_ebfmi <- function(post_warmup_sampler_diagnostics, ebfmi_threshold = .2, return_ebfmi = F) { - if (!is.null(post_warmup_sampler_diagnostics) && posterior::niterations(post_warmup_sampler_diagnostics) > 0) { + if (!is.null(post_warmup_sampler_diagnostics)) { pwsd <- posterior::as_draws_array(post_warmup_sampler_diagnostics) if (!("energy__" %in% dimnames(pwsd)$variable)) { if (! return_ebfmi) { From b1c43a1c8cb05905b49aa569a1d73efb0d60f39b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rok=20=C4=8Ce=C5=A1novar?= Date: Mon, 19 Jul 2021 12:19:07 +0200 Subject: [PATCH 19/44] Apply suggestions from code review Co-authored-by: Jonah Gabry --- R/utils.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/utils.R b/R/utils.R index 34d9eff6c..7a21e90a2 100644 --- a/R/utils.R +++ b/R/utils.R @@ -280,11 +280,11 @@ check_sampler_transitions_treedepth <- function(post_warmup_sampler_diagnostics, } } -check_ebfmi <- function(post_warmup_sampler_diagnostics, ebfmi_threshold = .2, return_ebfmi = F) { +check_ebfmi <- function(post_warmup_sampler_diagnostics, threshold = 0.2, return_ebfmi = FALSE) { if (!is.null(post_warmup_sampler_diagnostics)) { pwsd <- posterior::as_draws_array(post_warmup_sampler_diagnostics) if (!("energy__" %in% dimnames(pwsd)$variable)) { - if (! return_ebfmi) { + if (!return_ebfmi) { warning("E-BFMI not computed as the 'energy__' diagnostic could not be located") } else { stop("E-BFMI not computed as the 'energy__' diagnostic could not be located") From 1a95d7662cca3f3ebe56ceb7dea9ebbb2e1a3af7 Mon Sep 17 00:00:00 2001 From: Rok Cesnovar Date: Mon, 19 Jul 2021 13:57:30 +0200 Subject: [PATCH 20/44] cleanup --- R/utils.R | 50 +++++++++++++------------------------ tests/testthat/test-utils.R | 13 ++++++---- 2 files changed, 26 insertions(+), 37 deletions(-) diff --git a/R/utils.R b/R/utils.R index 7a21e90a2..91e205634 100644 --- a/R/utils.R +++ b/R/utils.R @@ -280,44 +280,30 @@ check_sampler_transitions_treedepth <- function(post_warmup_sampler_diagnostics, } } -check_ebfmi <- function(post_warmup_sampler_diagnostics, threshold = 0.2, return_ebfmi = FALSE) { +check_ebfmi <- function(post_warmup_sampler_diagnostics, threshold = 0.2) { + ebfmi <- NULL if (!is.null(post_warmup_sampler_diagnostics)) { - pwsd <- posterior::as_draws_array(post_warmup_sampler_diagnostics) - if (!("energy__" %in% dimnames(pwsd)$variable)) { - if (!return_ebfmi) { - warning("E-BFMI not computed as the 'energy__' diagnostic could not be located") - } else { - stop("E-BFMI not computed as the 'energy__' diagnostic could not be located") - } - } else if (dim(pwsd)[1] <= 2) { - if (! return_ebfmi) { - warning("E-BFMI is undefined for posterior chains of length less than 3") - } else { - stop("E-BFMI is undefined for posterior chains of length less than 3") - } + if (!("energy__" %in% posterior::variables(post_warmup_sampler_diagnostics))) { + warning("E-BFMI not computed as the 'energy__' diagnostic could not be located.", call. = FALSE) + } else if (posterior::niterations(post_warmup_sampler_diagnostics) < 3) { + warning("E-BFMI is undefined for posterior chains of length less than 3.", call. = FALSE) } else { - energy <- posterior::extract_variable_matrix(pwsd, "energy__") + energy <- posterior::extract_variable_matrix(post_warmup_sampler_diagnostics, "energy__") if (any(is.na(energy))) { - if (! return_ebfmi) { - warning("E-BFMI not computed because 'energy__' contains NAs") - } else { - stop("E-BFMI not computed because 'energy__' contains NAs") + warning("E-BFMI not computed because 'energy__' contains NAs.", call. = FALSE) + } else { + ebfmi <- apply(energy, 2, function(x) { + (sum(diff(x)^2) / length(x)) / stats::var(x) + }) + if (!is.null(threshold) && any(ebfmi < threshold)) { + message(paste0(sum(ebfmi < threshold), " of ", length(ebfmi), " chains had energy-based Bayesian fraction ", + "of missing information (E-BFMI) less than ", threshold, ", which may indicate poor exploration of the ", + "posterior.")) } } - ebfmi <- apply(energy, 2, function(x) { - (sum(diff(x)^2) / length(x)) / stats::var(x) - } - ) - if (any(ebfmi < ebfmi_threshold)) { - message(paste0(sum(ebfmi < ebfmi_threshold), " of ", length(ebfmi), " chains had energy-based Bayesian fraction ", - "of missing information (E-BFMI) less than ", ebfmi_threshold, ", which may indicate poor exploration of the ", - "posterior")) - } - if (return_ebfmi) { - ebfmi - } } - } + } + ebfmi } diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R index 10870c423..09e5e4ce4 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-utils.R @@ -175,17 +175,20 @@ test_that("matching_variables() works", { test_that("check_ebfmi works", { set.seed(1) energy_df <- data.frame("energy__" = rnorm(1000)) - expect_error(check_ebfmi(energy_df), NA) + expect_error(suppressWarnings(check_ebfmi(posterior::as_draws(energy_df))), NA) energy_df[1] <- 0 for(i in 1:999){ energy_df$energy__[i+1] <- energy_df$energy__[i] + rnorm(1, 0, 0.01) } + energy_df <- posterior::as_draws(energy_df) expect_message(check_ebfmi(energy_df), "fraction of missing information \\(E-BFMI\\) less than") energy_vec <- energy_df$energy__ - expect_equal(suppressMessages(check_ebfmi(energy_df, return_ebfmi = TRUE)), - (sum(diff(energy_vec)^2) / length(energy_vec)) / stats::var(energy_vec)) - energy_df <- data.frame("energy__" = 0) - expect_error(check_ebfmi(energy_df, return_ebfmi = TRUE), "E-BFMI is undefined for posterior chains of length less than 3") + check_val <- (sum(diff(energy_vec)^2) / length(energy_vec)) / stats::var(energy_vec) + expect_equal(as.numeric(check_ebfmi(energy_df)), check_val) + expect_equal(as.numeric(check_ebfmi(posterior::as_draws_array(energy_df))), check_val) + expect_equal(as.numeric(check_ebfmi(posterior::as_draws_list(energy_df))), check_val) + expect_equal(as.numeric(check_ebfmi(posterior::as_draws_matrix(energy_df))), check_val) + energy_df <- posterior::as_draws(data.frame("energy__" = 0)) expect_warning(check_ebfmi(energy_df), "E-BFMI is undefined for posterior chains of length less than 3") }) From 4079b59bdb4ed1b1296bca9a4f505e0cea218585 Mon Sep 17 00:00:00 2001 From: Rok Cesnovar Date: Mon, 19 Jul 2021 14:47:19 +0200 Subject: [PATCH 21/44] pull ebfmi compute out, fix tests --- R/utils.R | 23 ++- tests/testthat/test-utils.R | 363 ++++++++++++++++++------------------ 2 files changed, 199 insertions(+), 187 deletions(-) diff --git a/R/utils.R b/R/utils.R index 91e205634..e8fe4b6d6 100644 --- a/R/utils.R +++ b/R/utils.R @@ -280,8 +280,8 @@ check_sampler_transitions_treedepth <- function(post_warmup_sampler_diagnostics, } } -check_ebfmi <- function(post_warmup_sampler_diagnostics, threshold = 0.2) { - ebfmi <- NULL +ebfmi <- function(post_warmup_sampler_diagnostics) { + efbmi_val <- NULL if (!is.null(post_warmup_sampler_diagnostics)) { if (!("energy__" %in% posterior::variables(post_warmup_sampler_diagnostics))) { warning("E-BFMI not computed as the 'energy__' diagnostic could not be located.", call. = FALSE) @@ -292,18 +292,23 @@ check_ebfmi <- function(post_warmup_sampler_diagnostics, threshold = 0.2) { if (any(is.na(energy))) { warning("E-BFMI not computed because 'energy__' contains NAs.", call. = FALSE) } else { - ebfmi <- apply(energy, 2, function(x) { + efbmi_val <- apply(energy, 2, function(x) { (sum(diff(x)^2) / length(x)) / stats::var(x) }) - if (!is.null(threshold) && any(ebfmi < threshold)) { - message(paste0(sum(ebfmi < threshold), " of ", length(ebfmi), " chains had energy-based Bayesian fraction ", - "of missing information (E-BFMI) less than ", threshold, ", which may indicate poor exploration of the ", - "posterior.")) - } } } } - ebfmi + efbmi_val +} + +check_ebfmi <- function(post_warmup_sampler_diagnostics, threshold = 0.2) { + efbmi_val <- ebfmi(post_warmup_sampler_diagnostics) + if (any(efbmi_val < threshold)) { + message(paste0(sum(efbmi_val < threshold), " of ", length(efbmi_val), " chains had energy-based Bayesian fraction ", + "of missing information (E-BFMI) less than ", threshold, ", which may indicate poor exploration of the ", + "posterior.")) + } + invisible(NULL) } diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R index 09e5e4ce4..197f01cd3 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-utils.R @@ -1,181 +1,182 @@ context("utils") -if (not_on_cran()) { - set_cmdstan_path() - fit_mcmc <- testing_fit("logistic", method = "sample", - seed = 123, chains = 2) -} - -test_that("check_divergences() works", { - skip_on_cran() - csv_files <- c(test_path("resources", "csv", "model1-2-no-warmup.csv")) - csv_output <- read_cmdstan_csv(csv_files) - output <- "14 of 100 \\(14.0%\\) transitions ended with a divergence." - expect_message(check_divergences(csv_output$post_warmup_sampler_diagnostics), output) - - csv_files <- c(test_path("resources", "csv", "model1-2-no-warmup.csv"), - test_path("resources", "csv", "model1-2-no-warmup.csv")) - csv_output <- read_cmdstan_csv(csv_files) - output <- "28 of 200 \\(14.0%\\) transitions ended with a divergence." - expect_message(check_divergences(csv_output$post_warmup_sampler_diagnostics), output) - - csv_files <- c(test_path("resources", "csv", "model1-2-warmup.csv")) - csv_output <- read_cmdstan_csv(csv_files) - output <- "1 of 100 \\(1.0%\\) transitions ended with a divergence." - expect_message(check_divergences(csv_output$post_warmup_sampler_diagnostics), output) - - - fit_wramup_no_samples <- testing_fit("logistic", method = "sample", - seed = 123, chains = 1, - iter_sampling = 0, - iter_warmup = 10, - save_warmup = TRUE, - validate_csv = FALSE) - csv_output <- read_cmdstan_csv(fit_wramup_no_samples$output_files()) - expect_message(check_divergences(csv_output$post_warmup_sampler_diagnostics), regexp = NA) -}) - -test_that("check_sampler_transitions_treedepth() works", { - skip_on_cran() - csv_files <- c(test_path("resources", "csv", "model1-2-no-warmup.csv")) - csv_output <- read_cmdstan_csv(csv_files) - output <- "16 of 100 \\(16.0%\\) transitions hit the maximum treedepth limit of 5 or 2\\^5-1 leapfrog steps." - expect_message( - check_sampler_transitions_treedepth( - csv_output$post_warmup_sampler_diagnostics, - csv_output$metadata), - output - ) - - csv_files <- c(test_path("resources", "csv", "model1-2-no-warmup.csv"), - test_path("resources", "csv", "model1-2-no-warmup.csv")) - csv_output <- read_cmdstan_csv(csv_files) - output <- "32 of 200 \\(16.0%\\) transitions hit the maximum treedepth limit of 5 or 2\\^5-1 leapfrog steps." - expect_message( - check_sampler_transitions_treedepth( - csv_output$post_warmup_sampler_diagnostics, - csv_output$metadata), - output - ) - - csv_files <- c(test_path("resources", "csv", "model1-2-warmup.csv")) - csv_output <- read_cmdstan_csv(csv_files) - output <- "1 of 100 \\(1.0%\\) transitions hit the maximum treedepth limit of 5 or 2\\^5-1 leapfrog steps." - expect_message( - check_sampler_transitions_treedepth( - csv_output$post_warmup_sampler_diagnostics, - csv_output$metadata), - output - ) -}) - -test_that("cmdstan_summary works if bin/stansummary deleted file", { - skip_on_cran() - delete_and_run <- function() { - file.remove(file.path(cmdstan_path(), "bin", cmdstan_ext("stansummary"))) - fit_mcmc$cmdstan_summary() - } - expect_output(delete_and_run(), "Inference for Stan model: logistic_model\\n2 chains: each with iter") -}) - -test_that("cmdstan_diagnose works if bin/diagnose deleted file", { - skip_on_cran() - delete_and_run <- function() { - file.remove(file.path(cmdstan_path(), "bin", cmdstan_ext("diagnose"))) - fit_mcmc$cmdstan_diagnose() - } - expect_output(delete_and_run(), "Checking sampler transitions treedepth") -}) - -test_that("repair_path() fixes slashes", { - # all slashes should be single "/", and no trailing slash - expect_equal(repair_path("a//b\\c/"), "a/b/c") -}) - -test_that("repair_path works with zero length path or non-string path", { - expect_equal(repair_path(""), "") - expect_equal(repair_path(5), 5) -}) - -test_that("list_to_array works with empty list", { - expect_equal(list_to_array(list()), NULL) -}) - -test_that("list_to_array fails for non-numeric values", { - expect_error(list_to_array(list(k = "test"), name = "test-list"), - "All elements in list 'test-list' must be numeric!") -}) - -test_that("cmdstan_make_local() works", { - exisiting_make_local <- cmdstan_make_local() - make_local_path <- file.path(cmdstan_path(), "make", "local") - if (file.exists(make_local_path)) { - file.remove(make_local_path) - } - expect_equal(cmdstan_make_local(), NULL) - cpp_options = list( - "CXX" = "clang++", - "CXXFLAGS+= -march-native", - TEST1 = TRUE, - "TEST2" = FALSE - ) - expect_equal(cmdstan_make_local(cpp_options = cpp_options), - c( - "CXX=clang++", - "CXXFLAGS+= -march-native", - "TEST1=true", - "TEST2=false" - )) - expect_equal(cmdstan_make_local(cpp_options = list("TEST3" = TRUE)), - c( - "CXX=clang++", - "CXXFLAGS+= -march-native", - "TEST1=true", - "TEST2=false", - "TEST3=true" - )) - expect_equal(cmdstan_make_local(cpp_options = list("TEST4" = TRUE), append = FALSE), - c("TEST4=true")) - cmdstan_make_local(cpp_options = as.list(exisiting_make_local), append = FALSE) -}) - -test_that("matching_variables() works", { - ret <- matching_variables(c("beta"), c("alpha", "beta[1]", "beta[2]", "beta[3]")) - expect_equal( - ret$matching, - c("beta[1]", "beta[2]", "beta[3]") - ) - expect_equal(length(ret$not_found), 0) - - ret <- matching_variables(c("alpha"), c("alpha", "beta[1]", "beta[2]", "beta[3]")) - expect_equal( - ret$matching, - c("alpha") - ) - expect_equal(length(ret$not_found), 0) - - ret <- matching_variables(c("alpha", "theta"), c("alpha", "beta[1]", "beta[2]", "beta[3]")) - expect_equal( - ret$matching, - c("alpha") - ) - expect_equal( - ret$not_found, - c("theta") - ) - - ret <- matching_variables(c("alpha", "beta"), c("alpha", "beta[1]", "beta[2]", "beta[3]")) - expect_equal( - ret$matching, - c("alpha", "beta[1]", "beta[2]", "beta[3]") - ) - expect_equal(length(ret$not_found), 0) -}) - -test_that("check_ebfmi works", { +# if (not_on_cran()) { +# set_cmdstan_path() +# fit_mcmc <- testing_fit("logistic", method = "sample", +# seed = 123, chains = 2) +# } +# +# test_that("check_divergences() works", { +# skip_on_cran() +# csv_files <- c(test_path("resources", "csv", "model1-2-no-warmup.csv")) +# csv_output <- read_cmdstan_csv(csv_files) +# output <- "14 of 100 \\(14.0%\\) transitions ended with a divergence." +# expect_message(check_divergences(csv_output$post_warmup_sampler_diagnostics), output) +# +# csv_files <- c(test_path("resources", "csv", "model1-2-no-warmup.csv"), +# test_path("resources", "csv", "model1-2-no-warmup.csv")) +# csv_output <- read_cmdstan_csv(csv_files) +# output <- "28 of 200 \\(14.0%\\) transitions ended with a divergence." +# expect_message(check_divergences(csv_output$post_warmup_sampler_diagnostics), output) +# +# csv_files <- c(test_path("resources", "csv", "model1-2-warmup.csv")) +# csv_output <- read_cmdstan_csv(csv_files) +# output <- "1 of 100 \\(1.0%\\) transitions ended with a divergence." +# expect_message(check_divergences(csv_output$post_warmup_sampler_diagnostics), output) +# +# +# fit_wramup_no_samples <- testing_fit("logistic", method = "sample", +# seed = 123, chains = 1, +# iter_sampling = 0, +# iter_warmup = 10, +# save_warmup = TRUE, +# validate_csv = FALSE) +# csv_output <- read_cmdstan_csv(fit_wramup_no_samples$output_files()) +# expect_message(check_divergences(csv_output$post_warmup_sampler_diagnostics), regexp = NA) +# }) +# +# test_that("check_sampler_transitions_treedepth() works", { +# skip_on_cran() +# csv_files <- c(test_path("resources", "csv", "model1-2-no-warmup.csv")) +# csv_output <- read_cmdstan_csv(csv_files) +# output <- "16 of 100 \\(16.0%\\) transitions hit the maximum treedepth limit of 5 or 2\\^5-1 leapfrog steps." +# expect_message( +# check_sampler_transitions_treedepth( +# csv_output$post_warmup_sampler_diagnostics, +# csv_output$metadata), +# output +# ) +# +# csv_files <- c(test_path("resources", "csv", "model1-2-no-warmup.csv"), +# test_path("resources", "csv", "model1-2-no-warmup.csv")) +# csv_output <- read_cmdstan_csv(csv_files) +# output <- "32 of 200 \\(16.0%\\) transitions hit the maximum treedepth limit of 5 or 2\\^5-1 leapfrog steps." +# expect_message( +# check_sampler_transitions_treedepth( +# csv_output$post_warmup_sampler_diagnostics, +# csv_output$metadata), +# output +# ) +# +# csv_files <- c(test_path("resources", "csv", "model1-2-warmup.csv")) +# csv_output <- read_cmdstan_csv(csv_files) +# output <- "1 of 100 \\(1.0%\\) transitions hit the maximum treedepth limit of 5 or 2\\^5-1 leapfrog steps." +# expect_message( +# check_sampler_transitions_treedepth( +# csv_output$post_warmup_sampler_diagnostics, +# csv_output$metadata), +# output +# ) +# }) +# +# test_that("cmdstan_summary works if bin/stansummary deleted file", { +# skip_on_cran() +# delete_and_run <- function() { +# file.remove(file.path(cmdstan_path(), "bin", cmdstan_ext("stansummary"))) +# fit_mcmc$cmdstan_summary() +# } +# expect_output(delete_and_run(), "Inference for Stan model: logistic_model\\n2 chains: each with iter") +# }) +# +# test_that("cmdstan_diagnose works if bin/diagnose deleted file", { +# skip_on_cran() +# delete_and_run <- function() { +# file.remove(file.path(cmdstan_path(), "bin", cmdstan_ext("diagnose"))) +# fit_mcmc$cmdstan_diagnose() +# } +# expect_output(delete_and_run(), "Checking sampler transitions treedepth") +# }) +# +# test_that("repair_path() fixes slashes", { +# # all slashes should be single "/", and no trailing slash +# expect_equal(repair_path("a//b\\c/"), "a/b/c") +# }) +# +# test_that("repair_path works with zero length path or non-string path", { +# expect_equal(repair_path(""), "") +# expect_equal(repair_path(5), 5) +# }) +# +# test_that("list_to_array works with empty list", { +# expect_equal(list_to_array(list()), NULL) +# }) +# +# test_that("list_to_array fails for non-numeric values", { +# expect_error(list_to_array(list(k = "test"), name = "test-list"), +# "All elements in list 'test-list' must be numeric!") +# }) +# +# test_that("cmdstan_make_local() works", { +# exisiting_make_local <- cmdstan_make_local() +# make_local_path <- file.path(cmdstan_path(), "make", "local") +# if (file.exists(make_local_path)) { +# file.remove(make_local_path) +# } +# expect_equal(cmdstan_make_local(), NULL) +# cpp_options = list( +# "CXX" = "clang++", +# "CXXFLAGS+= -march-native", +# TEST1 = TRUE, +# "TEST2" = FALSE +# ) +# expect_equal(cmdstan_make_local(cpp_options = cpp_options), +# c( +# "CXX=clang++", +# "CXXFLAGS+= -march-native", +# "TEST1=true", +# "TEST2=false" +# )) +# expect_equal(cmdstan_make_local(cpp_options = list("TEST3" = TRUE)), +# c( +# "CXX=clang++", +# "CXXFLAGS+= -march-native", +# "TEST1=true", +# "TEST2=false", +# "TEST3=true" +# )) +# expect_equal(cmdstan_make_local(cpp_options = list("TEST4" = TRUE), append = FALSE), +# c("TEST4=true")) +# cmdstan_make_local(cpp_options = as.list(exisiting_make_local), append = FALSE) +# }) +# +# test_that("matching_variables() works", { +# ret <- matching_variables(c("beta"), c("alpha", "beta[1]", "beta[2]", "beta[3]")) +# expect_equal( +# ret$matching, +# c("beta[1]", "beta[2]", "beta[3]") +# ) +# expect_equal(length(ret$not_found), 0) +# +# ret <- matching_variables(c("alpha"), c("alpha", "beta[1]", "beta[2]", "beta[3]")) +# expect_equal( +# ret$matching, +# c("alpha") +# ) +# expect_equal(length(ret$not_found), 0) +# +# ret <- matching_variables(c("alpha", "theta"), c("alpha", "beta[1]", "beta[2]", "beta[3]")) +# expect_equal( +# ret$matching, +# c("alpha") +# ) +# expect_equal( +# ret$not_found, +# c("theta") +# ) +# +# ret <- matching_variables(c("alpha", "beta"), c("alpha", "beta[1]", "beta[2]", "beta[3]")) +# expect_equal( +# ret$matching, +# c("alpha", "beta[1]", "beta[2]", "beta[3]") +# ) +# expect_equal(length(ret$not_found), 0) +# }) + +test_that("check_ebfmi and computing ebfmi works", { set.seed(1) energy_df <- data.frame("energy__" = rnorm(1000)) expect_error(suppressWarnings(check_ebfmi(posterior::as_draws(energy_df))), NA) + expect_error(suppressWarnings(ebfmi(posterior::as_draws(energy_df))), NA) energy_df[1] <- 0 for(i in 1:999){ energy_df$energy__[i+1] <- energy_df$energy__[i] + rnorm(1, 0, 0.01) @@ -184,12 +185,18 @@ test_that("check_ebfmi works", { expect_message(check_ebfmi(energy_df), "fraction of missing information \\(E-BFMI\\) less than") energy_vec <- energy_df$energy__ check_val <- (sum(diff(energy_vec)^2) / length(energy_vec)) / stats::var(energy_vec) - expect_equal(as.numeric(check_ebfmi(energy_df)), check_val) - expect_equal(as.numeric(check_ebfmi(posterior::as_draws_array(energy_df))), check_val) - expect_equal(as.numeric(check_ebfmi(posterior::as_draws_list(energy_df))), check_val) - expect_equal(as.numeric(check_ebfmi(posterior::as_draws_matrix(energy_df))), check_val) + expect_equal(as.numeric(ebfmi(energy_df)), check_val) + expect_equal(as.numeric(ebfmi(posterior::as_draws_array(energy_df))), check_val) + expect_equal(as.numeric(ebfmi(posterior::as_draws_list(energy_df))), check_val) + expect_equal(as.numeric(ebfmi(posterior::as_draws_matrix(energy_df))), check_val) energy_df <- posterior::as_draws(data.frame("energy__" = 0)) - expect_warning(check_ebfmi(energy_df), "E-BFMI is undefined for posterior chains of length less than 3") + expect_warning(check_ebfmi(energy_df), "E-BFMI is undefined for posterior chains of length less than 3.") + expect_warning(ebfmi(energy_df), "E-BFMI is undefined for posterior chains of length less than 3.") + + energy_df <- posterior::as_draws(data.frame("somethingelse" = 0)) + expect_warning(check_ebfmi(energy_df), "E-BFMI not computed as the 'energy__' diagnostic could not be located.") + expect_warning(ebfmi(energy_df), "E-BFMI not computed as the 'energy__' diagnostic could not be located.") + }) test_that("require_suggested_package() works", { From 1cef6c10e82bf08a53945e2b4443e08e7ba8c941 Mon Sep 17 00:00:00 2001 From: Rok Cesnovar Date: Mon, 19 Jul 2021 14:49:13 +0200 Subject: [PATCH 22/44] updated NEWS.md --- NEWS.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/NEWS.md b/NEWS.md index 4d76ee399..474a0e430 100644 --- a/NEWS.md +++ b/NEWS.md @@ -23,6 +23,8 @@ multiple times with the same code. (#495, @martinmodrak) * `write_stan_json()` now handles data of class `"table"`. Tables are converted to vector, matrix, or array depending on the dimensions of the table. (#528) +* Added E-BFMI checks that run automatically post sampling. (#500, @jsocolar) + # cmdstanr 0.4.0 ### Bug fixes From c33ebeb9164d3911ca9c5e3f93604e8e5169cddf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rok=20=C4=8Ce=C5=A1novar?= Date: Wed, 25 Aug 2021 09:46:58 +0200 Subject: [PATCH 23/44] Update test-utils.R --- tests/testthat/test-utils.R | 342 ++++++++++++++++++------------------ 1 file changed, 171 insertions(+), 171 deletions(-) diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R index 197f01cd3..09e1e9bcd 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-utils.R @@ -1,176 +1,176 @@ context("utils") -# if (not_on_cran()) { -# set_cmdstan_path() -# fit_mcmc <- testing_fit("logistic", method = "sample", -# seed = 123, chains = 2) -# } -# -# test_that("check_divergences() works", { -# skip_on_cran() -# csv_files <- c(test_path("resources", "csv", "model1-2-no-warmup.csv")) -# csv_output <- read_cmdstan_csv(csv_files) -# output <- "14 of 100 \\(14.0%\\) transitions ended with a divergence." -# expect_message(check_divergences(csv_output$post_warmup_sampler_diagnostics), output) -# -# csv_files <- c(test_path("resources", "csv", "model1-2-no-warmup.csv"), -# test_path("resources", "csv", "model1-2-no-warmup.csv")) -# csv_output <- read_cmdstan_csv(csv_files) -# output <- "28 of 200 \\(14.0%\\) transitions ended with a divergence." -# expect_message(check_divergences(csv_output$post_warmup_sampler_diagnostics), output) -# -# csv_files <- c(test_path("resources", "csv", "model1-2-warmup.csv")) -# csv_output <- read_cmdstan_csv(csv_files) -# output <- "1 of 100 \\(1.0%\\) transitions ended with a divergence." -# expect_message(check_divergences(csv_output$post_warmup_sampler_diagnostics), output) -# -# -# fit_wramup_no_samples <- testing_fit("logistic", method = "sample", -# seed = 123, chains = 1, -# iter_sampling = 0, -# iter_warmup = 10, -# save_warmup = TRUE, -# validate_csv = FALSE) -# csv_output <- read_cmdstan_csv(fit_wramup_no_samples$output_files()) -# expect_message(check_divergences(csv_output$post_warmup_sampler_diagnostics), regexp = NA) -# }) -# -# test_that("check_sampler_transitions_treedepth() works", { -# skip_on_cran() -# csv_files <- c(test_path("resources", "csv", "model1-2-no-warmup.csv")) -# csv_output <- read_cmdstan_csv(csv_files) -# output <- "16 of 100 \\(16.0%\\) transitions hit the maximum treedepth limit of 5 or 2\\^5-1 leapfrog steps." -# expect_message( -# check_sampler_transitions_treedepth( -# csv_output$post_warmup_sampler_diagnostics, -# csv_output$metadata), -# output -# ) -# -# csv_files <- c(test_path("resources", "csv", "model1-2-no-warmup.csv"), -# test_path("resources", "csv", "model1-2-no-warmup.csv")) -# csv_output <- read_cmdstan_csv(csv_files) -# output <- "32 of 200 \\(16.0%\\) transitions hit the maximum treedepth limit of 5 or 2\\^5-1 leapfrog steps." -# expect_message( -# check_sampler_transitions_treedepth( -# csv_output$post_warmup_sampler_diagnostics, -# csv_output$metadata), -# output -# ) -# -# csv_files <- c(test_path("resources", "csv", "model1-2-warmup.csv")) -# csv_output <- read_cmdstan_csv(csv_files) -# output <- "1 of 100 \\(1.0%\\) transitions hit the maximum treedepth limit of 5 or 2\\^5-1 leapfrog steps." -# expect_message( -# check_sampler_transitions_treedepth( -# csv_output$post_warmup_sampler_diagnostics, -# csv_output$metadata), -# output -# ) -# }) -# -# test_that("cmdstan_summary works if bin/stansummary deleted file", { -# skip_on_cran() -# delete_and_run <- function() { -# file.remove(file.path(cmdstan_path(), "bin", cmdstan_ext("stansummary"))) -# fit_mcmc$cmdstan_summary() -# } -# expect_output(delete_and_run(), "Inference for Stan model: logistic_model\\n2 chains: each with iter") -# }) -# -# test_that("cmdstan_diagnose works if bin/diagnose deleted file", { -# skip_on_cran() -# delete_and_run <- function() { -# file.remove(file.path(cmdstan_path(), "bin", cmdstan_ext("diagnose"))) -# fit_mcmc$cmdstan_diagnose() -# } -# expect_output(delete_and_run(), "Checking sampler transitions treedepth") -# }) -# -# test_that("repair_path() fixes slashes", { -# # all slashes should be single "/", and no trailing slash -# expect_equal(repair_path("a//b\\c/"), "a/b/c") -# }) -# -# test_that("repair_path works with zero length path or non-string path", { -# expect_equal(repair_path(""), "") -# expect_equal(repair_path(5), 5) -# }) -# -# test_that("list_to_array works with empty list", { -# expect_equal(list_to_array(list()), NULL) -# }) -# -# test_that("list_to_array fails for non-numeric values", { -# expect_error(list_to_array(list(k = "test"), name = "test-list"), -# "All elements in list 'test-list' must be numeric!") -# }) -# -# test_that("cmdstan_make_local() works", { -# exisiting_make_local <- cmdstan_make_local() -# make_local_path <- file.path(cmdstan_path(), "make", "local") -# if (file.exists(make_local_path)) { -# file.remove(make_local_path) -# } -# expect_equal(cmdstan_make_local(), NULL) -# cpp_options = list( -# "CXX" = "clang++", -# "CXXFLAGS+= -march-native", -# TEST1 = TRUE, -# "TEST2" = FALSE -# ) -# expect_equal(cmdstan_make_local(cpp_options = cpp_options), -# c( -# "CXX=clang++", -# "CXXFLAGS+= -march-native", -# "TEST1=true", -# "TEST2=false" -# )) -# expect_equal(cmdstan_make_local(cpp_options = list("TEST3" = TRUE)), -# c( -# "CXX=clang++", -# "CXXFLAGS+= -march-native", -# "TEST1=true", -# "TEST2=false", -# "TEST3=true" -# )) -# expect_equal(cmdstan_make_local(cpp_options = list("TEST4" = TRUE), append = FALSE), -# c("TEST4=true")) -# cmdstan_make_local(cpp_options = as.list(exisiting_make_local), append = FALSE) -# }) -# -# test_that("matching_variables() works", { -# ret <- matching_variables(c("beta"), c("alpha", "beta[1]", "beta[2]", "beta[3]")) -# expect_equal( -# ret$matching, -# c("beta[1]", "beta[2]", "beta[3]") -# ) -# expect_equal(length(ret$not_found), 0) -# -# ret <- matching_variables(c("alpha"), c("alpha", "beta[1]", "beta[2]", "beta[3]")) -# expect_equal( -# ret$matching, -# c("alpha") -# ) -# expect_equal(length(ret$not_found), 0) -# -# ret <- matching_variables(c("alpha", "theta"), c("alpha", "beta[1]", "beta[2]", "beta[3]")) -# expect_equal( -# ret$matching, -# c("alpha") -# ) -# expect_equal( -# ret$not_found, -# c("theta") -# ) -# -# ret <- matching_variables(c("alpha", "beta"), c("alpha", "beta[1]", "beta[2]", "beta[3]")) -# expect_equal( -# ret$matching, -# c("alpha", "beta[1]", "beta[2]", "beta[3]") -# ) -# expect_equal(length(ret$not_found), 0) -# }) +if (not_on_cran()) { + set_cmdstan_path() + fit_mcmc <- testing_fit("logistic", method = "sample", + seed = 123, chains = 2) +} + +test_that("check_divergences() works", { + skip_on_cran() + csv_files <- c(test_path("resources", "csv", "model1-2-no-warmup.csv")) + csv_output <- read_cmdstan_csv(csv_files) + output <- "14 of 100 \\(14.0%\\) transitions ended with a divergence." + expect_message(check_divergences(csv_output$post_warmup_sampler_diagnostics), output) + + csv_files <- c(test_path("resources", "csv", "model1-2-no-warmup.csv"), + test_path("resources", "csv", "model1-2-no-warmup.csv")) + csv_output <- read_cmdstan_csv(csv_files) + output <- "28 of 200 \\(14.0%\\) transitions ended with a divergence." + expect_message(check_divergences(csv_output$post_warmup_sampler_diagnostics), output) + + csv_files <- c(test_path("resources", "csv", "model1-2-warmup.csv")) + csv_output <- read_cmdstan_csv(csv_files) + output <- "1 of 100 \\(1.0%\\) transitions ended with a divergence." + expect_message(check_divergences(csv_output$post_warmup_sampler_diagnostics), output) + + + fit_wramup_no_samples <- testing_fit("logistic", method = "sample", + seed = 123, chains = 1, + iter_sampling = 0, + iter_warmup = 10, + save_warmup = TRUE, + validate_csv = FALSE) + csv_output <- read_cmdstan_csv(fit_wramup_no_samples$output_files()) + expect_message(check_divergences(csv_output$post_warmup_sampler_diagnostics), regexp = NA) +}) + +test_that("check_sampler_transitions_treedepth() works", { + skip_on_cran() + csv_files <- c(test_path("resources", "csv", "model1-2-no-warmup.csv")) + csv_output <- read_cmdstan_csv(csv_files) + output <- "16 of 100 \\(16.0%\\) transitions hit the maximum treedepth limit of 5 or 2\\^5-1 leapfrog steps." + expect_message( + check_sampler_transitions_treedepth( + csv_output$post_warmup_sampler_diagnostics, + csv_output$metadata), + output + ) + + csv_files <- c(test_path("resources", "csv", "model1-2-no-warmup.csv"), + test_path("resources", "csv", "model1-2-no-warmup.csv")) + csv_output <- read_cmdstan_csv(csv_files) + output <- "32 of 200 \\(16.0%\\) transitions hit the maximum treedepth limit of 5 or 2\\^5-1 leapfrog steps." + expect_message( + check_sampler_transitions_treedepth( + csv_output$post_warmup_sampler_diagnostics, + csv_output$metadata), + output + ) + + csv_files <- c(test_path("resources", "csv", "model1-2-warmup.csv")) + csv_output <- read_cmdstan_csv(csv_files) + output <- "1 of 100 \\(1.0%\\) transitions hit the maximum treedepth limit of 5 or 2\\^5-1 leapfrog steps." + expect_message( + check_sampler_transitions_treedepth( + csv_output$post_warmup_sampler_diagnostics, + csv_output$metadata), + output + ) +}) + +test_that("cmdstan_summary works if bin/stansummary deleted file", { + skip_on_cran() + delete_and_run <- function() { + file.remove(file.path(cmdstan_path(), "bin", cmdstan_ext("stansummary"))) + fit_mcmc$cmdstan_summary() + } + expect_output(delete_and_run(), "Inference for Stan model: logistic_model\\n2 chains: each with iter") +}) + +test_that("cmdstan_diagnose works if bin/diagnose deleted file", { + skip_on_cran() + delete_and_run <- function() { + file.remove(file.path(cmdstan_path(), "bin", cmdstan_ext("diagnose"))) + fit_mcmc$cmdstan_diagnose() + } + expect_output(delete_and_run(), "Checking sampler transitions treedepth") +}) + +test_that("repair_path() fixes slashes", { + # all slashes should be single "/", and no trailing slash + expect_equal(repair_path("a//b\\c/"), "a/b/c") +}) + +test_that("repair_path works with zero length path or non-string path", { + expect_equal(repair_path(""), "") + expect_equal(repair_path(5), 5) +}) + +test_that("list_to_array works with empty list", { + expect_equal(list_to_array(list()), NULL) +}) + +test_that("list_to_array fails for non-numeric values", { + expect_error(list_to_array(list(k = "test"), name = "test-list"), + "All elements in list 'test-list' must be numeric!") +}) + +test_that("cmdstan_make_local() works", { + exisiting_make_local <- cmdstan_make_local() + make_local_path <- file.path(cmdstan_path(), "make", "local") + if (file.exists(make_local_path)) { + file.remove(make_local_path) + } + expect_equal(cmdstan_make_local(), NULL) + cpp_options = list( + "CXX" = "clang++", + "CXXFLAGS+= -march-native", + TEST1 = TRUE, + "TEST2" = FALSE + ) + expect_equal(cmdstan_make_local(cpp_options = cpp_options), + c( + "CXX=clang++", + "CXXFLAGS+= -march-native", + "TEST1=true", + "TEST2=false" + )) + expect_equal(cmdstan_make_local(cpp_options = list("TEST3" = TRUE)), + c( + "CXX=clang++", + "CXXFLAGS+= -march-native", + "TEST1=true", + "TEST2=false", + "TEST3=true" + )) + expect_equal(cmdstan_make_local(cpp_options = list("TEST4" = TRUE), append = FALSE), + c("TEST4=true")) + cmdstan_make_local(cpp_options = as.list(exisiting_make_local), append = FALSE) +}) + +test_that("matching_variables() works", { + ret <- matching_variables(c("beta"), c("alpha", "beta[1]", "beta[2]", "beta[3]")) + expect_equal( + ret$matching, + c("beta[1]", "beta[2]", "beta[3]") + ) + expect_equal(length(ret$not_found), 0) + + ret <- matching_variables(c("alpha"), c("alpha", "beta[1]", "beta[2]", "beta[3]")) + expect_equal( + ret$matching, + c("alpha") + ) + expect_equal(length(ret$not_found), 0) + + ret <- matching_variables(c("alpha", "theta"), c("alpha", "beta[1]", "beta[2]", "beta[3]")) + expect_equal( + ret$matching, + c("alpha") + ) + expect_equal( + ret$not_found, + c("theta") + ) + + ret <- matching_variables(c("alpha", "beta"), c("alpha", "beta[1]", "beta[2]", "beta[3]")) + expect_equal( + ret$matching, + c("alpha", "beta[1]", "beta[2]", "beta[3]") + ) + expect_equal(length(ret$not_found), 0) +}) test_that("check_ebfmi and computing ebfmi works", { set.seed(1) From 0297c96b9b4d269ec32f457236a28280cadb6de5 Mon Sep 17 00:00:00 2001 From: jgabry Date: Tue, 2 Nov 2021 09:20:50 -0600 Subject: [PATCH 24/44] minor edits --- R/fit.R | 6 ++++-- R/utils.R | 13 ++++++++----- tests/testthat/test-utils.R | 8 ++++---- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/R/fit.R b/R/fit.R index bb72892e1..e652f0ffe 100644 --- a/R/fit.R +++ b/R/fit.R @@ -858,8 +858,10 @@ CmdStanMCMC <- R6::R6Class( } else { if (self$runset$args$validate_csv) { fixed_param <- runset$args$method_args$fixed_param - private$read_csv_(variables = "", - sampler_diagnostics = if (!fixed_param) c("treedepth__", "divergent__", "energy__") else "") + private$read_csv_( + variables = "", + sampler_diagnostics = if (!fixed_param) c("treedepth__", "divergent__", "energy__") else "" + ) if (!fixed_param) { check_divergences(private$sampler_diagnostics_) check_sampler_transitions_treedepth(private$sampler_diagnostics_, private$metadata_) diff --git a/R/utils.R b/R/utils.R index d6cd7136f..fbc8fb8e9 100644 --- a/R/utils.R +++ b/R/utils.R @@ -282,9 +282,9 @@ ebfmi <- function(post_warmup_sampler_diagnostics) { efbmi_val <- NULL if (!is.null(post_warmup_sampler_diagnostics)) { if (!("energy__" %in% posterior::variables(post_warmup_sampler_diagnostics))) { - warning("E-BFMI not computed as the 'energy__' diagnostic could not be located.", call. = FALSE) + warning("E-BFMI not computed because the 'energy__' diagnostic could not be located.", call. = FALSE) } else if (posterior::niterations(post_warmup_sampler_diagnostics) < 3) { - warning("E-BFMI is undefined for posterior chains of length less than 3.", call. = FALSE) + warning("E-BFMI not computed because it is undefined for posterior chains of length less than 3.", call. = FALSE) } else { energy <- posterior::extract_variable_matrix(post_warmup_sampler_diagnostics, "energy__") if (any(is.na(energy))) { @@ -302,9 +302,12 @@ ebfmi <- function(post_warmup_sampler_diagnostics) { check_ebfmi <- function(post_warmup_sampler_diagnostics, threshold = 0.2) { efbmi_val <- ebfmi(post_warmup_sampler_diagnostics) if (any(efbmi_val < threshold)) { - message(paste0(sum(efbmi_val < threshold), " of ", length(efbmi_val), " chains had energy-based Bayesian fraction ", - "of missing information (E-BFMI) less than ", threshold, ", which may indicate poor exploration of the ", - "posterior.")) + message( + "Warning: ", sum(efbmi_val < threshold), " of ", length(efbmi_val), + " chains had energy-based Bayesian fraction of missing information (E-BFMI)", + " less than ", threshold, ".", + "\nThis may indicate poor exploration of the posterior." + ) } invisible(NULL) } diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R index 09e1e9bcd..a6dc23a2c 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-utils.R @@ -190,12 +190,12 @@ test_that("check_ebfmi and computing ebfmi works", { expect_equal(as.numeric(ebfmi(posterior::as_draws_list(energy_df))), check_val) expect_equal(as.numeric(ebfmi(posterior::as_draws_matrix(energy_df))), check_val) energy_df <- posterior::as_draws(data.frame("energy__" = 0)) - expect_warning(check_ebfmi(energy_df), "E-BFMI is undefined for posterior chains of length less than 3.") - expect_warning(ebfmi(energy_df), "E-BFMI is undefined for posterior chains of length less than 3.") + expect_warning(check_ebfmi(energy_df), "E-BFMI not computed because it is undefined for posterior chains of length less than 3.") + expect_warning(ebfmi(energy_df), "E-BFMI not computed because it is undefined for posterior chains of length less than 3.") energy_df <- posterior::as_draws(data.frame("somethingelse" = 0)) - expect_warning(check_ebfmi(energy_df), "E-BFMI not computed as the 'energy__' diagnostic could not be located.") - expect_warning(ebfmi(energy_df), "E-BFMI not computed as the 'energy__' diagnostic could not be located.") + expect_warning(check_ebfmi(energy_df), "E-BFMI not computed because the 'energy__' diagnostic could not be located.") + expect_warning(ebfmi(energy_df), "E-BFMI not computed because the 'energy__' diagnostic could not be located.") }) From 21947a385ee182815c25bea7e8e8bfbec809ac2a Mon Sep 17 00:00:00 2001 From: jgabry Date: Wed, 3 Nov 2021 12:43:31 -0600 Subject: [PATCH 25/44] draft of diagnose_sampler method --- R/csv.R | 2 +- R/fit.R | 74 +++++++++++++++++++++++++++++++++++++ R/utils.R | 14 +++++-- tests/testthat/test-utils.R | 8 ++-- 4 files changed, 90 insertions(+), 8 deletions(-) diff --git a/R/csv.R b/R/csv.R index b33724a1a..cc0611eae 100644 --- a/R/csv.R +++ b/R/csv.R @@ -446,7 +446,7 @@ CmdStanMCMC_CSV <- R6::R6Class( initialize = function(csv_contents, files, check_diagnostics = TRUE) { if (check_diagnostics) { check_divergences(csv_contents$post_warmup_sampler_diagnostics) - check_sampler_transitions_treedepth(csv_contents$post_warmup_sampler_diagnostics, csv_contents$metadata) + check_max_treedepth(csv_contents$post_warmup_sampler_diagnostics, csv_contents$metadata) check_ebfmi(csv_contents$post_warmup_sampler_diagnostics) } private$output_files_ <- files diff --git a/R/fit.R b/R/fit.R index e652f0ffe..7118bc351 100644 --- a/R/fit.R +++ b/R/fit.R @@ -1109,6 +1109,80 @@ sampler_diagnostics <- function(inc_warmup = FALSE, format = getOption("cmdstanr } CmdStanMCMC$set("public", name = "sampler_diagnostics", value = sampler_diagnostics) +#' Warnings and summaries of sampler diagnostics +#' +#' @name fit-method-diagnose_sampler +#' @aliases diagnose_sampler +#' @description Warnings and summaries of sampler diagnostics. To instead get +#' the underlying values of the sampler diagnostics for each iteration and +#' chain use the [`$sampler_diagnostics()`][fit-method-sampler_diagnostics] +#' method. +#' +#' @param diagnostics (character vector) One or more diagnostics to check. The +#' currently supported diagnostics are `"divergences`, `"treedepth"`, and +#' `"ebfmi`. +#' @param quiet (logical) Should messages about the diagnostics be displayed? +#' The values of the diagnostics are always returned but if `quiet = FALSE` +#' (the default) the warning messages about the diagnostics are also +#' displayed. +#' +#' @return A list with as many named elements as `diagnostics` selected. The +#' possible elements and their values are: +#' * `"divergences"`: The number of divergences. +#' * `"max_treedepths"`: The number of times `max_treedepth` was hit. +#' * `"ebfmi"`: A vector of E-BFMI values, one per chain. +#' +#' @seealso [`CmdStanMCMC`] and the +#' [`$sampler_diagnostics()`][fit-method-sampler_diagnostics] method +#' +#' @examples +#' \dontrun{ +#' fit <- cmdstanr_example("schools") +#' fit$diagnose_sampler() +#' fit$diagnose_sampler(quiet = TRUE) +#' } +#' +diagnose_sampler <- function(diagnostics = c("divergences", "treedepth", "ebfmi"), quiet = FALSE) { + if (is.null(private$sampler_diagnostics_) && + !length(self$output_files(include_failed = FALSE))) { + stop("No chains finished successfully. Unable to retrieve the sampler diagnostics.", call. = FALSE) + } + diagnostics <- match.arg( + diagnostics, + choices = available_diagnostics(), + several.ok = TRUE + ) + post_warmup_sampler_diagnostics <- self$sampler_diagnostics(inc_warmup = FALSE) + out <- list() + if ("divergences" %in% diagnostics) { + if (quiet) { + divergences <- suppressMessages(check_divergences(post_warmup_sampler_diagnostics)) + } else { + divergences <- check_divergences(post_warmup_sampler_diagnostics) + } + out[["divergences"]] <- divergences + } + if ("treedepth" %in% diagnostics) { + if (quiet) { + max_treedepth_hit <- suppressMessages(check_max_treedepth(post_warmup_sampler_diagnostics, self$metadata())) + } else { + max_treedepth_hit <- check_max_treedepth(post_warmup_sampler_diagnostics, self$metadata()) + } + out[["max_treedepths"]] <- max_treedepth_hit + } + if ("ebfmi" %in% diagnostics) { + if (quiet) { + ebfmi <- suppressMessages(check_ebfmi(post_warmup_sampler_diagnostics)) + } else { + ebfmi <- check_ebfmi(post_warmup_sampler_diagnostics) + } + out[["ebfmi"]] <- ebfmi %||% NA + } + out +} +CmdStanMCMC$set("public", name = "diagnose_sampler", value = diagnose_sampler) + + #' Extract inverse metric (mass matrix) after MCMC #' #' @name fit-method-inv_metric diff --git a/R/utils.R b/R/utils.R index fbc8fb8e9..b75450ec6 100644 --- a/R/utils.R +++ b/R/utils.R @@ -239,6 +239,7 @@ set_num_threads <- function(num_threads) { # convergence checks ------------------------------------------------------ check_divergences <- function(post_warmup_sampler_diagnostics) { + num_of_divergences <- NULL if (!is.null(post_warmup_sampler_diagnostics)) { divergences <- posterior::extract_variable_matrix(post_warmup_sampler_diagnostics, "divergent__") num_of_draws <- length(divergences) @@ -257,9 +258,11 @@ check_divergences <- function(post_warmup_sampler_diagnostics) { ) } } + invisible(num_of_divergences) } -check_sampler_transitions_treedepth <- function(post_warmup_sampler_diagnostics, metadata) { +check_max_treedepth <- function(post_warmup_sampler_diagnostics, metadata) { + max_treedepth_hit <- NULL if (!is.null(post_warmup_sampler_diagnostics)) { treedepth <- posterior::extract_variable_matrix(post_warmup_sampler_diagnostics, "treedepth__") num_of_draws <- length(treedepth) @@ -276,6 +279,7 @@ check_sampler_transitions_treedepth <- function(post_warmup_sampler_diagnostics, ) } } + invisible(max_treedepth_hit) } ebfmi <- function(post_warmup_sampler_diagnostics) { @@ -306,12 +310,16 @@ check_ebfmi <- function(post_warmup_sampler_diagnostics, threshold = 0.2) { "Warning: ", sum(efbmi_val < threshold), " of ", length(efbmi_val), " chains had energy-based Bayesian fraction of missing information (E-BFMI)", " less than ", threshold, ".", - "\nThis may indicate poor exploration of the posterior." + "\nThis may indicate poor exploration of the posterior.\n" ) } - invisible(NULL) + invisible(efbmi_val) } +# used in various places to validate the selected diagnostics +available_diagnostics <- function() { + c("divergences", "treedepth", "ebfmi") +} # draws formatting -------------------------------------------------------- diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R index a6dc23a2c..f0e49446c 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-utils.R @@ -35,13 +35,13 @@ test_that("check_divergences() works", { expect_message(check_divergences(csv_output$post_warmup_sampler_diagnostics), regexp = NA) }) -test_that("check_sampler_transitions_treedepth() works", { +test_that("check_max_treedepth() works", { skip_on_cran() csv_files <- c(test_path("resources", "csv", "model1-2-no-warmup.csv")) csv_output <- read_cmdstan_csv(csv_files) output <- "16 of 100 \\(16.0%\\) transitions hit the maximum treedepth limit of 5 or 2\\^5-1 leapfrog steps." expect_message( - check_sampler_transitions_treedepth( + check_max_treedepth( csv_output$post_warmup_sampler_diagnostics, csv_output$metadata), output @@ -52,7 +52,7 @@ test_that("check_sampler_transitions_treedepth() works", { csv_output <- read_cmdstan_csv(csv_files) output <- "32 of 200 \\(16.0%\\) transitions hit the maximum treedepth limit of 5 or 2\\^5-1 leapfrog steps." expect_message( - check_sampler_transitions_treedepth( + check_max_treedepth( csv_output$post_warmup_sampler_diagnostics, csv_output$metadata), output @@ -62,7 +62,7 @@ test_that("check_sampler_transitions_treedepth() works", { csv_output <- read_cmdstan_csv(csv_files) output <- "1 of 100 \\(1.0%\\) transitions hit the maximum treedepth limit of 5 or 2\\^5-1 leapfrog steps." expect_message( - check_sampler_transitions_treedepth( + check_max_treedepth( csv_output$post_warmup_sampler_diagnostics, csv_output$metadata), output From e51aec2192768116d9e23d0f76cf8a6855145617 Mon Sep 17 00:00:00 2001 From: jgabry Date: Wed, 3 Nov 2021 12:43:43 -0600 Subject: [PATCH 26/44] deprecate validate_csv in favor of diagnostics argument --- R/args.R | 15 +++++++-- R/fit.R | 9 +++--- R/model.R | 44 +++++++++++++++++++++------ man-roxygen/model-sample-args.R | 27 ++++++++-------- man/fit-method-diagnose_sampler.Rd | 49 ++++++++++++++++++++++++++++++ man/model-method-sample.Rd | 18 ++++++----- man/model-method-sample_mpi.Rd | 20 +++++++----- 7 files changed, 137 insertions(+), 45 deletions(-) create mode 100644 man/fit-method-diagnose_sampler.Rd diff --git a/R/args.R b/R/args.R index 8db0d6ed0..a51255dc6 100644 --- a/R/args.R +++ b/R/args.R @@ -36,7 +36,6 @@ CmdStanArgs <- R6::R6Class( refresh = NULL, output_dir = NULL, output_basename = NULL, - validate_csv = TRUE, sig_figs = NULL, opencl_ids = NULL, model_variables = NULL) { @@ -52,7 +51,6 @@ CmdStanArgs <- R6::R6Class( self$method_args <- method_args self$method <- self$method_args$method self$save_latent_dynamics <- save_latent_dynamics - self$validate_csv <- validate_csv self$using_tempdir <- is.null(output_dir) if (getRversion() < "3.5.0") { self$output_dir <- output_dir %||% tempdir() @@ -196,7 +194,8 @@ SampleArgs <- R6::R6Class( init_buffer = NULL, term_buffer = NULL, window = NULL, - fixed_param = FALSE) { + fixed_param = FALSE, + diagnostics = NULL) { self$iter_warmup <- iter_warmup self$iter_sampling <- iter_sampling @@ -209,6 +208,11 @@ SampleArgs <- R6::R6Class( self$metric <- metric self$inv_metric <- inv_metric self$fixed_param <- fixed_param + self$diagnostics <- diagnostics + if (identical(self$diagnostics, "")) { + self$diagnostics <- NULL + } + if (!is.null(inv_metric)) { if (!is.null(metric_file)) { stop("Only one of inv_metric and metric_file can be specified.", @@ -636,6 +640,11 @@ validate_sample_args <- function(self, num_procs) { validate_metric(self$metric) validate_metric_file(self$metric_file, num_procs) + checkmate::assert_character(self$diagnostics, null.ok = TRUE, any.missing = FALSE) + if (!is.null(self$diagnostics)) { + checkmate::assert_subset(self$diagnostics, empty.ok = FALSE, choices = available_diagnostics()) + } + invisible(TRUE) } diff --git a/R/fit.R b/R/fit.R index 7118bc351..2fd132020 100644 --- a/R/fit.R +++ b/R/fit.R @@ -856,16 +856,15 @@ CmdStanMCMC <- R6::R6Class( warning("No chains finished successfully. Unable to retrieve the fit.", call. = FALSE) } else { - if (self$runset$args$validate_csv) { + if (!is.null(self$runset$args$method_args$diagnostics)) { + diagnostics <- self$runset$method_args$diagnostics fixed_param <- runset$args$method_args$fixed_param private$read_csv_( variables = "", - sampler_diagnostics = if (!fixed_param) c("treedepth__", "divergent__", "energy__") else "" + sampler_diagnostics = if (!fixed_param) diagnostics else "" ) if (!fixed_param) { - check_divergences(private$sampler_diagnostics_) - check_sampler_transitions_treedepth(private$sampler_diagnostics_, private$metadata_) - check_ebfmi(private$sampler_diagnostics_) + invisible(self$diagnose_sampler(diagnostics = diagnostics, quiet = FALSE)) } } } diff --git a/R/model.R b/R/model.R index 9a1153453..2f97b1a6b 100644 --- a/R/model.R +++ b/R/model.R @@ -772,7 +772,7 @@ CmdStanModel$set("public", name = "check_syntax", value = check_syntax) #' #' @template model-common-args #' @template model-sample-args -#' @param cores,num_cores,num_chains,num_warmup,num_samples,save_extra_diagnostics,max_depth,stepsize +#' @param cores,num_cores,num_chains,num_warmup,num_samples,save_extra_diagnostics,max_depth,stepsize,validate_csv #' Deprecated and will be removed in a future release. #' #' @return A [`CmdStanMCMC`] object. @@ -808,14 +808,15 @@ sample <- function(data = NULL, term_buffer = NULL, window = NULL, fixed_param = FALSE, - validate_csv = TRUE, show_messages = TRUE, + diagnostics = c("divergences", "treedepth", "ebfmi"), # deprecated cores = NULL, num_cores = NULL, num_chains = NULL, num_warmup = NULL, num_samples = NULL, + validate_csv = NULL, save_extra_diagnostics = NULL, max_depth = NULL, stepsize = NULL) { @@ -852,6 +853,17 @@ sample <- function(data = NULL, warning("'save_extra_diagnostics' is deprecated. Please use 'save_latent_dynamics' instead.") save_latent_dynamics <- save_extra_diagnostics } + if (!is.null(validate_csv)) { + warning("'validate_csv' is deprecated. Please set 'diagnostics=NULL' instead.") + if (is.logical(validate_csv)) { + if (validate_csv) { + diagnostics <- c("divergences", "treedepth", "ebfmi") + } else { + diagnostics <- NULL + } + } + } + if (cmdstan_version() >= "2.27.0" && !fixed_param) { if (self$has_stan_file() && file.exists(self$stan_file())) { if (!is.null(self$variables()) && length(self$variables()$parameters) == 0) { @@ -888,7 +900,8 @@ sample <- function(data = NULL, init_buffer = init_buffer, term_buffer = term_buffer, window = window, - fixed_param = fixed_param + fixed_param = fixed_param, + diagnostics = diagnostics ) args <- CmdStanArgs$new( method_args = sample_args, @@ -905,7 +918,6 @@ sample <- function(data = NULL, output_dir = output_dir, output_basename = output_basename, sig_figs = sig_figs, - validate_csv = validate_csv, opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()), model_variables = model_variables ) @@ -1000,8 +1012,22 @@ sample_mpi <- function(data = NULL, window = NULL, fixed_param = FALSE, sig_figs = NULL, - validate_csv = TRUE, - show_messages = TRUE) { + show_messages = TRUE, + diagnostics = c("divergences", "treedepth", "ebfmi"), + # deprecated + validate_csv = TRUE) { + + if (!is.null(validate_csv)) { + warning("'validate_csv' is deprecated. Please set 'diagnostics=NULL' instead.") + if (is.logical(validate_csv)) { + if (validate_csv) { + diagnostics <- c("divergences", "treedepth", "ebfmi") + } else { + diagnostics <- NULL + } + } + } + if (fixed_param) { chains <- 1 save_warmup <- FALSE @@ -1030,7 +1056,8 @@ sample_mpi <- function(data = NULL, init_buffer = init_buffer, term_buffer = term_buffer, window = window, - fixed_param = fixed_param + fixed_param = fixed_param, + diagnostics = diagnostics ) args <- CmdStanArgs$new( method_args = sample_args, @@ -1046,7 +1073,6 @@ sample_mpi <- function(data = NULL, refresh = refresh, output_dir = output_dir, output_basename = output_basename, - validate_csv = validate_csv, sig_figs = sig_figs, model_variables = model_variables ) @@ -1557,7 +1583,7 @@ model_variables <- function(stan_file, include_paths = NULL, allow_undefined = F allow_undefined_arg <- "--allow-undefined" } else { allow_undefined_arg <- NULL - } + } out_file <- tempfile(fileext = ".json") run_log <- processx::run( command = stanc_cmd(), diff --git a/man-roxygen/model-sample-args.R b/man-roxygen/model-sample-args.R index 632c20e4c..3d1e4233c 100644 --- a/man-roxygen/model-sample-args.R +++ b/man-roxygen/model-sample-args.R @@ -17,19 +17,6 @@ #' `parallel_chains*threads_per_chain`. For an example of using threading see #' the Stan case study [Reduce Sum: A Minimal #' Example](https://mc-stan.org/users/documentation/case-studies/reduce_sum_tutorial.html). -#' -#' @param show_messages (logical) When `TRUE` (the default), prints all -#' informational messages, for example rejection of the current proposal. -#' Disable if you wish silence these messages, but this is not recommended -#' unless you are very sure that the model is correct up to numerical error. -#' If the messages are silenced then the `$output()` method of the resulting -#' fit object can be used to display all the silenced messages. -#' -#' @param validate_csv (logical) When `TRUE` (the default), validate the -#' sampling results in the csv files. Disable if you wish to manually read in -#' the sampling results and validate them yourself, for example using -#' [read_cmdstan_csv()]. -#' #' @param iter_sampling (positive integer) The number of post-warmup iterations #' to run per chain. Note: in the CmdStan User's Guide this is referred to as #' `num_samples`. @@ -42,7 +29,6 @@ #' accessing the draws. #' @param thin (positive integer) The period between saved samples. This should #' typically be left at its default (no thinning) unless memory is a problem. -#' #' @param max_treedepth (positive integer) The maximum allowed tree depth for #' the NUTS engine. See the _Tree Depth_ section of the CmdStan User's Guide #' for more details. @@ -89,4 +75,17 @@ #' quantities block. If the parameters block is empty then using #' `fixed_param=TRUE` is mandatory. When `fixed_param=TRUE` the `chains` and #' `parallel_chains` arguments will be set to `1`. +#' @param show_messages (logical) When `TRUE` (the default), prints all +#' informational messages, for example rejection of the current proposal. +#' Disable if you wish silence these messages, but this is not recommended +#' unless you are very sure that the model is correct up to numerical error. +#' If the messages are silenced then the `$output()` method of the resulting +#' fit object can be used to display all the silenced messages. +#' @param diagnostics (character vector) The diagnostics to automatically check +#' and warn about after sampling. Setting this to an empty string `""` or +#' `NULL` can be used to prevent CmdStanR from automatically reading in the +#' sampler diagnostics from CSV if you wish to manually read in the results +#' and validate them yourself, for example using [read_cmdstan_csv()]. The +#' currently available diagnostics are `"divergences"`, `"treedepth"`, +#' `"ebfmi"` (the default is to check all of them). #' diff --git a/man/fit-method-diagnose_sampler.Rd b/man/fit-method-diagnose_sampler.Rd new file mode 100644 index 000000000..c58bc57e6 --- /dev/null +++ b/man/fit-method-diagnose_sampler.Rd @@ -0,0 +1,49 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/fit.R +\name{fit-method-diagnose_sampler} +\alias{fit-method-diagnose_sampler} +\alias{diagnose_sampler} +\title{Warnings and summaries of sampler diagnostics} +\usage{ +diagnose_sampler( + diagnostics = c("divergences", "treedepth", "ebfmi"), + quiet = FALSE +) +} +\arguments{ +\item{diagnostics}{(character vector) One or more diagnostics to check. The +currently supported diagnostics are \verb{"divergences}, \code{"treedepth"}, and +\verb{"ebfmi}.} + +\item{quiet}{(logical) Should messages about the diagnostics be displayed? +The values of the diagnostics are always returned but if \code{quiet = FALSE} +(the default) the warning messages about the diagnostics are also +displayed.} +} +\value{ +A list with as many named elements as \code{diagnostics} selected. The +possible elements and their values are: +\itemize{ +\item \code{"divergences"}: The number of divergences. +\item \code{"max_treedepths"}: The number of times \code{max_treedepth} was hit. +\item \code{"ebfmi"}: A vector of E-BFMI values, one per chain. +} +} +\description{ +Warnings and summaries of sampler diagnostics. To instead get +the underlying values of the sampler diagnostics for each iteration and +chain use the \code{\link[=fit-method-sampler_diagnostics]{$sampler_diagnostics()}} +method. +} +\examples{ +\dontrun{ +fit <- cmdstanr_example("schools") +fit$diagnose_sampler() +fit$diagnose_sampler(quiet = TRUE) +} + +} +\seealso{ +\code{\link{CmdStanMCMC}} and the +\code{\link[=fit-method-sampler_diagnostics]{$sampler_diagnostics()}} method +} diff --git a/man/model-method-sample.Rd b/man/model-method-sample.Rd index 8ffcfafa5..e3cadc650 100644 --- a/man/model-method-sample.Rd +++ b/man/model-method-sample.Rd @@ -34,13 +34,14 @@ sample( term_buffer = NULL, window = NULL, fixed_param = FALSE, - validate_csv = TRUE, show_messages = TRUE, + diagnostics = c("divergences", "treedepth", "ebfmi"), cores = NULL, num_cores = NULL, num_chains = NULL, num_warmup = NULL, num_samples = NULL, + validate_csv = NULL, save_extra_diagnostics = NULL, max_depth = NULL, stepsize = NULL @@ -236,11 +237,6 @@ quantities block. If the parameters block is empty then using \code{fixed_param=TRUE} is mandatory. When \code{fixed_param=TRUE} the \code{chains} and \code{parallel_chains} arguments will be set to \code{1}.} -\item{validate_csv}{(logical) When \code{TRUE} (the default), validate the -sampling results in the csv files. Disable if you wish to manually read in -the sampling results and validate them yourself, for example using -\code{\link[=read_cmdstan_csv]{read_cmdstan_csv()}}.} - \item{show_messages}{(logical) When \code{TRUE} (the default), prints all informational messages, for example rejection of the current proposal. Disable if you wish silence these messages, but this is not recommended @@ -248,7 +244,15 @@ unless you are very sure that the model is correct up to numerical error. If the messages are silenced then the \verb{$output()} method of the resulting fit object can be used to display all the silenced messages.} -\item{cores, num_cores, num_chains, num_warmup, num_samples, save_extra_diagnostics, max_depth, stepsize}{Deprecated and will be removed in a future release.} +\item{diagnostics}{(character vector) The diagnostics to automatically check +and warn about after sampling. Setting this to an empty string \code{""} or +\code{NULL} can be used to prevent CmdStanR from automatically reading in the +sampler diagnostics from CSV if you wish to manually read in the results +and validate them yourself, for example using \code{\link[=read_cmdstan_csv]{read_cmdstan_csv()}}. The +currently available diagnostics are \code{"divergences"}, \code{"treedepth"}, +\code{"ebfmi"} (the default is to check all of them).} + +\item{cores, num_cores, num_chains, num_warmup, num_samples, save_extra_diagnostics, max_depth, stepsize, validate_csv}{Deprecated and will be removed in a future release.} } \value{ A \code{\link{CmdStanMCMC}} object. diff --git a/man/model-method-sample_mpi.Rd b/man/model-method-sample_mpi.Rd index 844ad21fa..9fca93d9a 100644 --- a/man/model-method-sample_mpi.Rd +++ b/man/model-method-sample_mpi.Rd @@ -33,8 +33,9 @@ sample_mpi( window = NULL, fixed_param = FALSE, sig_figs = NULL, - validate_csv = TRUE, - show_messages = TRUE + show_messages = TRUE, + diagnostics = c("divergences", "treedepth", "ebfmi"), + validate_csv = TRUE ) } \arguments{ @@ -215,17 +216,22 @@ values with 6 significant figures. The upper limit for \code{sig_figs} is 18. Increasing this value will result in larger output CSV files and thus an increased usage of disk space.} -\item{validate_csv}{(logical) When \code{TRUE} (the default), validate the -sampling results in the csv files. Disable if you wish to manually read in -the sampling results and validate them yourself, for example using -\code{\link[=read_cmdstan_csv]{read_cmdstan_csv()}}.} - \item{show_messages}{(logical) When \code{TRUE} (the default), prints all informational messages, for example rejection of the current proposal. Disable if you wish silence these messages, but this is not recommended unless you are very sure that the model is correct up to numerical error. If the messages are silenced then the \verb{$output()} method of the resulting fit object can be used to display all the silenced messages.} + +\item{diagnostics}{(character vector) The diagnostics to automatically check +and warn about after sampling. Setting this to an empty string \code{""} or +\code{NULL} can be used to prevent CmdStanR from automatically reading in the +sampler diagnostics from CSV if you wish to manually read in the results +and validate them yourself, for example using \code{\link[=read_cmdstan_csv]{read_cmdstan_csv()}}. The +currently available diagnostics are \code{"divergences"}, \code{"treedepth"}, +\code{"ebfmi"} (the default is to check all of them).} + +\item{validate_csv}{Deprecated and will be removed in a future release.} } \value{ A \code{\link{CmdStanMCMC}} object. From d4c0d81100d2918d0ac7e1146a72e759210cb3f8 Mon Sep 17 00:00:00 2001 From: jgabry Date: Wed, 3 Nov 2021 12:54:47 -0600 Subject: [PATCH 27/44] Update utils.R --- R/utils.R | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/R/utils.R b/R/utils.R index b75450ec6..42aae4ddb 100644 --- a/R/utils.R +++ b/R/utils.R @@ -313,10 +313,11 @@ check_ebfmi <- function(post_warmup_sampler_diagnostics, threshold = 0.2) { "\nThis may indicate poor exploration of the posterior.\n" ) } - invisible(efbmi_val) + invisible(unname(efbmi_val)) } -# used in various places to validate the selected diagnostics +# used in various places (e.g., fit$diagnose_sampler() and validate_sample_args()) +# to validate the selected diagnostics available_diagnostics <- function() { c("divergences", "treedepth", "ebfmi") } From 005c634946a5fbe339de14dec99ebba443548d18 Mon Sep 17 00:00:00 2001 From: jgabry Date: Wed, 3 Nov 2021 17:18:53 -0600 Subject: [PATCH 28/44] return vectors of diagnostics, one element per chain --- R/fit.R | 6 +- R/model.R | 5 +- R/utils.R | 46 +++++++------- man-roxygen/model-sample-args.R | 12 +++- man/fit-method-diagnose_sampler.Rd | 6 +- man/model-method-sample.Rd | 15 ++++- man/model-method-sample_mpi.Rd | 12 +++- tests/testthat/test-fit-mcmc.R | 20 ++++++ tests/testthat/test-utils.R | 97 ++++++++++++++++++++---------- 9 files changed, 154 insertions(+), 65 deletions(-) diff --git a/R/fit.R b/R/fit.R index 2fd132020..a6db33b74 100644 --- a/R/fit.R +++ b/R/fit.R @@ -1127,9 +1127,9 @@ CmdStanMCMC$set("public", name = "sampler_diagnostics", value = sampler_diagnost #' #' @return A list with as many named elements as `diagnostics` selected. The #' possible elements and their values are: -#' * `"divergences"`: The number of divergences. -#' * `"max_treedepths"`: The number of times `max_treedepth` was hit. -#' * `"ebfmi"`: A vector of E-BFMI values, one per chain. +#' * `"divergences"`: A vector of the number of divergences per chain. +#' * `"max_treedepths"`: A vector of the number of times `max_treedepth` was hit per chain. +#' * `"ebfmi"`: A vector of E-BFMI values per chain. #' #' @seealso [`CmdStanMCMC`] and the #' [`$sampler_diagnostics()`][fit-method-sampler_diagnostics] method diff --git a/R/model.R b/R/model.R index 2f97b1a6b..c586d8230 100644 --- a/R/model.R +++ b/R/model.R @@ -770,6 +770,9 @@ CmdStanModel$set("public", name = "check_syntax", value = check_syntax) #' [CmdStan User’s Guide](https://mc-stan.org/docs/cmdstan-guide/) #' for more details. #' +#' After model fitting any diagnostics specified via the `diagnostics` +#' argument will be checked and warnings will be printed if warranted. +#' #' @template model-common-args #' @template model-sample-args #' @param cores,num_cores,num_chains,num_warmup,num_samples,save_extra_diagnostics,max_depth,stepsize,validate_csv @@ -854,7 +857,7 @@ sample <- function(data = NULL, save_latent_dynamics <- save_extra_diagnostics } if (!is.null(validate_csv)) { - warning("'validate_csv' is deprecated. Please set 'diagnostics=NULL' instead.") + warning("'validate_csv' is deprecated. Please use 'diagnostics' instead.") if (is.logical(validate_csv)) { if (validate_csv) { diagnostics <- c("divergences", "treedepth", "ebfmi") diff --git a/R/utils.R b/R/utils.R index 42aae4ddb..ae223228c 100644 --- a/R/utils.R +++ b/R/utils.R @@ -239,15 +239,16 @@ set_num_threads <- function(num_threads) { # convergence checks ------------------------------------------------------ check_divergences <- function(post_warmup_sampler_diagnostics) { - num_of_divergences <- NULL + num_divergences_per_chain <- NULL if (!is.null(post_warmup_sampler_diagnostics)) { divergences <- posterior::extract_variable_matrix(post_warmup_sampler_diagnostics, "divergent__") - num_of_draws <- length(divergences) - num_of_divergences <- sum(divergences) - if (!is.na(num_of_divergences) && num_of_divergences > 0) { - percentage_divergences <- 100 * num_of_divergences / num_of_draws + num_divergences_per_chain <- colSums(divergences) + num_divergences <- sum(num_divergences_per_chain) + num_draws <- length(divergences) + if (!is.na(num_divergences) && num_divergences > 0) { + percentage_divergences <- 100 * num_divergences / num_draws message( - "\nWarning: ", num_of_divergences, " of ", num_of_draws, + "\nWarning: ", num_divergences, " of ", num_draws, " (", (format(round(percentage_divergences, 0), nsmall = 1)), "%)", " transitions ended with a divergence.\n", "This may indicate insufficient exploration of the posterior distribution.\n", @@ -258,19 +259,20 @@ check_divergences <- function(post_warmup_sampler_diagnostics) { ) } } - invisible(num_of_divergences) + invisible(unname(num_divergences_per_chain)) } check_max_treedepth <- function(post_warmup_sampler_diagnostics, metadata) { - max_treedepth_hit <- NULL + num_max_treedepths_per_chain <- NULL if (!is.null(post_warmup_sampler_diagnostics)) { - treedepth <- posterior::extract_variable_matrix(post_warmup_sampler_diagnostics, "treedepth__") - num_of_draws <- length(treedepth) - max_treedepth_hit <- sum(treedepth >= metadata$max_treedepth) - if (!is.na(max_treedepth_hit) && max_treedepth_hit > 0) { - percentage_max_treedepth <- 100 * max_treedepth_hit / num_of_draws + treedepths <- posterior::extract_variable_matrix(post_warmup_sampler_diagnostics, "treedepth__") + num_max_treedepths_per_chain <- apply(treedepths, 2, function(x) sum(x >= metadata$max_treedepth)) + num_max_treedepths <- sum(num_max_treedepths_per_chain) + num_draws <- length(treedepths) + if (!is.na(num_max_treedepths) && num_max_treedepths > 0) { + percentage_max_treedepths <- 100 * num_max_treedepths / num_draws message( - max_treedepth_hit, " of ", num_of_draws, " (", (format(round(percentage_max_treedepth, 0), nsmall = 1)), "%)", + num_max_treedepths, " of ", num_draws, " (", (format(round(percentage_max_treedepths, 0), nsmall = 1)), "%)", " transitions hit the maximum treedepth limit of ", metadata$max_treedepth, " or 2^", metadata$max_treedepth, "-1 leapfrog steps.\n", "Trajectories that are prematurely terminated due to this limit will result in slow exploration.\n", @@ -279,11 +281,11 @@ check_max_treedepth <- function(post_warmup_sampler_diagnostics, metadata) { ) } } - invisible(max_treedepth_hit) + invisible(unname(num_max_treedepths_per_chain)) } ebfmi <- function(post_warmup_sampler_diagnostics) { - efbmi_val <- NULL + efbmi_per_chain <- NULL if (!is.null(post_warmup_sampler_diagnostics)) { if (!("energy__" %in% posterior::variables(post_warmup_sampler_diagnostics))) { warning("E-BFMI not computed because the 'energy__' diagnostic could not be located.", call. = FALSE) @@ -294,26 +296,26 @@ ebfmi <- function(post_warmup_sampler_diagnostics) { if (any(is.na(energy))) { warning("E-BFMI not computed because 'energy__' contains NAs.", call. = FALSE) } else { - efbmi_val <- apply(energy, 2, function(x) { + efbmi_per_chain <- apply(energy, 2, function(x) { (sum(diff(x)^2) / length(x)) / stats::var(x) }) } } } - efbmi_val + efbmi_per_chain } check_ebfmi <- function(post_warmup_sampler_diagnostics, threshold = 0.2) { - efbmi_val <- ebfmi(post_warmup_sampler_diagnostics) - if (any(efbmi_val < threshold)) { + efbmi_per_chain <- ebfmi(post_warmup_sampler_diagnostics) + if (any(efbmi_per_chain < threshold)) { message( - "Warning: ", sum(efbmi_val < threshold), " of ", length(efbmi_val), + "Warning: ", sum(efbmi_per_chain < threshold), " of ", length(efbmi_per_chain), " chains had energy-based Bayesian fraction of missing information (E-BFMI)", " less than ", threshold, ".", "\nThis may indicate poor exploration of the posterior.\n" ) } - invisible(unname(efbmi_val)) + invisible(unname(efbmi_per_chain)) } # used in various places (e.g., fit$diagnose_sampler() and validate_sample_args()) diff --git a/man-roxygen/model-sample-args.R b/man-roxygen/model-sample-args.R index 3d1e4233c..c7e2145c7 100644 --- a/man-roxygen/model-sample-args.R +++ b/man-roxygen/model-sample-args.R @@ -87,5 +87,15 @@ #' sampler diagnostics from CSV if you wish to manually read in the results #' and validate them yourself, for example using [read_cmdstan_csv()]. The #' currently available diagnostics are `"divergences"`, `"treedepth"`, -#' `"ebfmi"` (the default is to check all of them). +#' and `"ebfmi"` (the default is to check all of them). +#' +#' These diagnostics are also available after fitting. The +#' [`$sampler_diagnostics()`][fit-method-sampler_diagnostics] method provides +#' access the diagnostic values for each iteration and the +#' [`$diagnose_sampler()`][fit-method-diagnose_sampler] method provides +#' summaries of the diagnostics and can regenerate the warning messages. +#' +#' Diagnostics like R-hat and effective sample size are not currently +#' available via the `diagnostics` argument but can be checked after fitting +#' using the [`$summary()`][fit-method-summary] method. #' diff --git a/man/fit-method-diagnose_sampler.Rd b/man/fit-method-diagnose_sampler.Rd index c58bc57e6..2ce8d65d2 100644 --- a/man/fit-method-diagnose_sampler.Rd +++ b/man/fit-method-diagnose_sampler.Rd @@ -24,9 +24,9 @@ displayed.} A list with as many named elements as \code{diagnostics} selected. The possible elements and their values are: \itemize{ -\item \code{"divergences"}: The number of divergences. -\item \code{"max_treedepths"}: The number of times \code{max_treedepth} was hit. -\item \code{"ebfmi"}: A vector of E-BFMI values, one per chain. +\item \code{"divergences"}: A vector of the number of divergences per chain. +\item \code{"max_treedepths"}: A vector of the number of times \code{max_treedepth} was hit per chain. +\item \code{"ebfmi"}: A vector of E-BFMI values per chain. } } \description{ diff --git a/man/model-method-sample.Rd b/man/model-method-sample.Rd index e3cadc650..f35a59b71 100644 --- a/man/model-method-sample.Rd +++ b/man/model-method-sample.Rd @@ -250,7 +250,17 @@ and warn about after sampling. Setting this to an empty string \code{""} or sampler diagnostics from CSV if you wish to manually read in the results and validate them yourself, for example using \code{\link[=read_cmdstan_csv]{read_cmdstan_csv()}}. The currently available diagnostics are \code{"divergences"}, \code{"treedepth"}, -\code{"ebfmi"} (the default is to check all of them).} +and \code{"ebfmi"} (the default is to check all of them). + +These diagnostics are also available after fitting. The +\code{\link[=fit-method-sampler_diagnostics]{$sampler_diagnostics()}} method provides +access the diagnostic values for each iteration and the +\code{\link[=fit-method-diagnose_sampler]{$diagnose_sampler()}} method provides +summaries of the diagnostics and can regenerate the warning messages. + +Diagnostics like R-hat and effective sample size are not currently +available via the \code{diagnostics} argument but can be checked after fitting +using the \code{\link[=fit-method-summary]{$summary()}} method.} \item{cores, num_cores, num_chains, num_warmup, num_samples, save_extra_diagnostics, max_depth, stepsize, validate_csv}{Deprecated and will be removed in a future release.} } @@ -267,6 +277,9 @@ Any argument left as \code{NULL} will default to the default value used by the installed version of CmdStan. See the \href{https://mc-stan.org/docs/cmdstan-guide/}{CmdStan User’s Guide} for more details. + +After model fitting any diagnostics specified via the \code{diagnostics} +argument will be checked and warnings will be printed if warranted. } \examples{ \dontrun{ diff --git a/man/model-method-sample_mpi.Rd b/man/model-method-sample_mpi.Rd index 9fca93d9a..45d9c2bfd 100644 --- a/man/model-method-sample_mpi.Rd +++ b/man/model-method-sample_mpi.Rd @@ -229,7 +229,17 @@ and warn about after sampling. Setting this to an empty string \code{""} or sampler diagnostics from CSV if you wish to manually read in the results and validate them yourself, for example using \code{\link[=read_cmdstan_csv]{read_cmdstan_csv()}}. The currently available diagnostics are \code{"divergences"}, \code{"treedepth"}, -\code{"ebfmi"} (the default is to check all of them).} +and \code{"ebfmi"} (the default is to check all of them). + +These diagnostics are also available after fitting. The +\code{\link[=fit-method-sampler_diagnostics]{$sampler_diagnostics()}} method provides +access the diagnostic values for each iteration and the +\code{\link[=fit-method-diagnose_sampler]{$diagnose_sampler()}} method provides +summaries of the diagnostics and can regenerate the warning messages. + +Diagnostics like R-hat and effective sample size are not currently +available via the \code{diagnostics} argument but can be checked after fitting +using the \code{\link[=fit-method-summary]{$summary()}} method.} \item{validate_csv}{Deprecated and will be removed in a future release.} } diff --git a/tests/testthat/test-fit-mcmc.R b/tests/testthat/test-fit-mcmc.R index d695fb0a9..284878f01 100644 --- a/tests/testthat/test-fit-mcmc.R +++ b/tests/testthat/test-fit-mcmc.R @@ -340,3 +340,23 @@ test_that("draws() errors if invalid format", { "The supplied draws format is not valid" ) }) + +test_that("diagnose_sampler() works", { + fit <- suppressMessages(cmdstanr_example("schools")) + expect_message( + diagnostics <- fit$diagnose_sampler(), + "transitions ended with a divergence" + ) + expect_equal( + diagnostics$divergences, + suppressMessages(check_divergences(fit$sampler_diagnostics())) + ) + expect_equal( + diagnostics$max_treedepths, + suppressMessages(check_max_treedepth(fit$sampler_diagnostics(), fit$metadata())) + ) + expect_equal( + diagnostics$ebfmi, + suppressMessages(check_ebfmi(fit$sampler_diagnostics())) + ) +}) diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R index f0e49446c..e65029141 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-utils.R @@ -6,18 +6,29 @@ if (not_on_cran()) { seed = 123, chains = 2) } + +# diagnostic checks ------------------------------------------------------- + test_that("check_divergences() works", { skip_on_cran() csv_files <- c(test_path("resources", "csv", "model1-2-no-warmup.csv")) csv_output <- read_cmdstan_csv(csv_files) output <- "14 of 100 \\(14.0%\\) transitions ended with a divergence." - expect_message(check_divergences(csv_output$post_warmup_sampler_diagnostics), output) + expect_message(divs <- check_divergences(csv_output$post_warmup_sampler_diagnostics), output) + expect_equal(divs, 14) csv_files <- c(test_path("resources", "csv", "model1-2-no-warmup.csv"), test_path("resources", "csv", "model1-2-no-warmup.csv")) csv_output <- read_cmdstan_csv(csv_files) output <- "28 of 200 \\(14.0%\\) transitions ended with a divergence." - expect_message(check_divergences(csv_output$post_warmup_sampler_diagnostics), output) + expect_message(divs <- check_divergences(csv_output$post_warmup_sampler_diagnostics), output) + expect_equal(divs, c(14, 14)) + + # force different number of divergences per chain just to test + csv_output$post_warmup_sampler_diagnostics[1, 1:2, "divergent__"] <- c(0, 1) + output <- "27 of 200 \\(14.0%\\) transitions ended with a divergence." + expect_message(divs <- check_divergences(csv_output$post_warmup_sampler_diagnostics), output) + expect_equal(divs, c(13, 14)) csv_files <- c(test_path("resources", "csv", "model1-2-warmup.csv")) csv_output <- read_cmdstan_csv(csv_files) @@ -30,9 +41,10 @@ test_that("check_divergences() works", { iter_sampling = 0, iter_warmup = 10, save_warmup = TRUE, - validate_csv = FALSE) + diagnostics = "") csv_output <- read_cmdstan_csv(fit_wramup_no_samples$output_files()) - expect_message(check_divergences(csv_output$post_warmup_sampler_diagnostics), regexp = NA) + expect_message(divs <- check_divergences(csv_output$post_warmup_sampler_diagnostics), regexp = NA) + expect_null(divs) }) test_that("check_max_treedepth() works", { @@ -41,22 +53,35 @@ test_that("check_max_treedepth() works", { csv_output <- read_cmdstan_csv(csv_files) output <- "16 of 100 \\(16.0%\\) transitions hit the maximum treedepth limit of 5 or 2\\^5-1 leapfrog steps." expect_message( - check_max_treedepth( + max_tds <- check_max_treedepth( csv_output$post_warmup_sampler_diagnostics, csv_output$metadata), output ) + expect_equal(max_tds, 16) csv_files <- c(test_path("resources", "csv", "model1-2-no-warmup.csv"), test_path("resources", "csv", "model1-2-no-warmup.csv")) csv_output <- read_cmdstan_csv(csv_files) output <- "32 of 200 \\(16.0%\\) transitions hit the maximum treedepth limit of 5 or 2\\^5-1 leapfrog steps." expect_message( - check_max_treedepth( + max_tds <- check_max_treedepth( + csv_output$post_warmup_sampler_diagnostics, + csv_output$metadata), + output + ) + expect_equal(max_tds, c(16, 16)) + + # force different number of max treedepths per chain just to test + csv_output$post_warmup_sampler_diagnostics[1, 1:2, "treedepth__"] <- c(1, 15) + output <- "31 of 200 \\(16.0%\\) transitions hit the maximum treedepth limit of 5 or 2\\^5-1 leapfrog steps." + expect_message( + max_tds <- check_max_treedepth( csv_output$post_warmup_sampler_diagnostics, csv_output$metadata), output ) + expect_equal(max_tds, c(15, 16)) csv_files <- c(test_path("resources", "csv", "model1-2-warmup.csv")) csv_output <- read_cmdstan_csv(csv_files) @@ -69,6 +94,35 @@ test_that("check_max_treedepth() works", { ) }) +test_that("check_ebfmi and computing ebfmi works", { + set.seed(1) + energy_df <- data.frame("energy__" = rnorm(1000)) + expect_error(suppressWarnings(check_ebfmi(posterior::as_draws(energy_df))), NA) + expect_error(suppressWarnings(ebfmi(posterior::as_draws(energy_df))), NA) + energy_df[1] <- 0 + for(i in 1:999){ + energy_df$energy__[i+1] <- energy_df$energy__[i] + rnorm(1, 0, 0.01) + } + energy_df <- posterior::as_draws(energy_df) + expect_message(check_ebfmi(energy_df), "fraction of missing information \\(E-BFMI\\) less than") + energy_vec <- energy_df$energy__ + check_val <- (sum(diff(energy_vec)^2) / length(energy_vec)) / stats::var(energy_vec) + expect_equal(as.numeric(ebfmi(energy_df)), check_val) + expect_equal(as.numeric(ebfmi(posterior::as_draws_array(energy_df))), check_val) + expect_equal(as.numeric(ebfmi(posterior::as_draws_list(energy_df))), check_val) + expect_equal(as.numeric(ebfmi(posterior::as_draws_matrix(energy_df))), check_val) + energy_df <- posterior::as_draws(data.frame("energy__" = 0)) + expect_warning(check_ebfmi(energy_df), "E-BFMI not computed because it is undefined for posterior chains of length less than 3.") + expect_warning(ebfmi(energy_df), "E-BFMI not computed because it is undefined for posterior chains of length less than 3.") + + energy_df <- posterior::as_draws(data.frame("somethingelse" = 0)) + expect_warning(check_ebfmi(energy_df), "E-BFMI not computed because the 'energy__' diagnostic could not be located.") + expect_warning(ebfmi(energy_df), "E-BFMI not computed because the 'energy__' diagnostic could not be located.") +}) + + +# cmdstan utilities ------------------------------------------------------- + test_that("cmdstan_summary works if bin/stansummary deleted file", { skip_on_cran() delete_and_run <- function() { @@ -87,6 +141,9 @@ test_that("cmdstan_diagnose works if bin/diagnose deleted file", { expect_output(delete_and_run(), "Checking sampler transitions treedepth") }) + +# misc -------------------------------------------------------------------- + test_that("repair_path() fixes slashes", { # all slashes should be single "/", and no trailing slash expect_equal(repair_path("a//b\\c/"), "a/b/c") @@ -172,36 +229,10 @@ test_that("matching_variables() works", { expect_equal(length(ret$not_found), 0) }) -test_that("check_ebfmi and computing ebfmi works", { - set.seed(1) - energy_df <- data.frame("energy__" = rnorm(1000)) - expect_error(suppressWarnings(check_ebfmi(posterior::as_draws(energy_df))), NA) - expect_error(suppressWarnings(ebfmi(posterior::as_draws(energy_df))), NA) - energy_df[1] <- 0 - for(i in 1:999){ - energy_df$energy__[i+1] <- energy_df$energy__[i] + rnorm(1, 0, 0.01) - } - energy_df <- posterior::as_draws(energy_df) - expect_message(check_ebfmi(energy_df), "fraction of missing information \\(E-BFMI\\) less than") - energy_vec <- energy_df$energy__ - check_val <- (sum(diff(energy_vec)^2) / length(energy_vec)) / stats::var(energy_vec) - expect_equal(as.numeric(ebfmi(energy_df)), check_val) - expect_equal(as.numeric(ebfmi(posterior::as_draws_array(energy_df))), check_val) - expect_equal(as.numeric(ebfmi(posterior::as_draws_list(energy_df))), check_val) - expect_equal(as.numeric(ebfmi(posterior::as_draws_matrix(energy_df))), check_val) - energy_df <- posterior::as_draws(data.frame("energy__" = 0)) - expect_warning(check_ebfmi(energy_df), "E-BFMI not computed because it is undefined for posterior chains of length less than 3.") - expect_warning(ebfmi(energy_df), "E-BFMI not computed because it is undefined for posterior chains of length less than 3.") - - energy_df <- posterior::as_draws(data.frame("somethingelse" = 0)) - expect_warning(check_ebfmi(energy_df), "E-BFMI not computed because the 'energy__' diagnostic could not be located.") - expect_warning(ebfmi(energy_df), "E-BFMI not computed because the 'energy__' diagnostic could not be located.") - -}) - test_that("require_suggested_package() works", { expect_error( require_suggested_package("not_a_real_package"), "Please install the 'not_a_real_package' package to use this function." ) }) + From 8df1ed44be3e0c0b70322c41161a5b03d44bbd29 Mon Sep 17 00:00:00 2001 From: jgabry Date: Thu, 4 Nov 2021 11:41:28 -0600 Subject: [PATCH 29/44] diagnose_sampler -> diagnostic_summary --- R/fit.R | 22 +++++++++++++------ R/utils.R | 2 +- man-roxygen/model-sample-args.R | 2 +- ...er.Rd => fit-method-diagnostic_summary.Rd} | 20 ++++++++++++----- man/model-method-sample.Rd | 2 +- man/model-method-sample_mpi.Rd | 2 +- tests/testthat/test-fit-mcmc.R | 4 ++-- 7 files changed, 35 insertions(+), 19 deletions(-) rename man/{fit-method-diagnose_sampler.Rd => fit-method-diagnostic_summary.Rd} (70%) diff --git a/R/fit.R b/R/fit.R index a6db33b74..4d4fe0071 100644 --- a/R/fit.R +++ b/R/fit.R @@ -864,7 +864,7 @@ CmdStanMCMC <- R6::R6Class( sampler_diagnostics = if (!fixed_param) diagnostics else "" ) if (!fixed_param) { - invisible(self$diagnose_sampler(diagnostics = diagnostics, quiet = FALSE)) + invisible(self$diagnostic_summary(diagnostics = diagnostics, quiet = FALSE)) } } } @@ -1110,13 +1110,21 @@ CmdStanMCMC$set("public", name = "sampler_diagnostics", value = sampler_diagnost #' Warnings and summaries of sampler diagnostics #' -#' @name fit-method-diagnose_sampler -#' @aliases diagnose_sampler +#' @name fit-method-diagnostic_summary +#' @aliases diagnostic_summary #' @description Warnings and summaries of sampler diagnostics. To instead get #' the underlying values of the sampler diagnostics for each iteration and #' chain use the [`$sampler_diagnostics()`][fit-method-sampler_diagnostics] #' method. #' +#' This method is similar to [`$cmdstan_diagnose()`][fit-method-cmdstan_summary] +#' but everything is done in \R (rather than calling CmdStan utilities) and +#' the summaries are returned as a list (not just printed). +#' +#' Currently parameter-specific diagnostics like R-hat and effective sample +#' size are not handled by this method. Those diagnostics are provided via the +#' [`$summary()`][fit-method-summary] method. +#' #' @param diagnostics (character vector) One or more diagnostics to check. The #' currently supported diagnostics are `"divergences`, `"treedepth"`, and #' `"ebfmi`. @@ -1137,11 +1145,11 @@ CmdStanMCMC$set("public", name = "sampler_diagnostics", value = sampler_diagnost #' @examples #' \dontrun{ #' fit <- cmdstanr_example("schools") -#' fit$diagnose_sampler() -#' fit$diagnose_sampler(quiet = TRUE) +#' fit$diagnostic_summary() +#' fit$diagnostic_summary(quiet = TRUE) #' } #' -diagnose_sampler <- function(diagnostics = c("divergences", "treedepth", "ebfmi"), quiet = FALSE) { +diagnostic_summary <- function(diagnostics = c("divergences", "treedepth", "ebfmi"), quiet = FALSE) { if (is.null(private$sampler_diagnostics_) && !length(self$output_files(include_failed = FALSE))) { stop("No chains finished successfully. Unable to retrieve the sampler diagnostics.", call. = FALSE) @@ -1179,7 +1187,7 @@ diagnose_sampler <- function(diagnostics = c("divergences", "treedepth", "ebfmi" } out } -CmdStanMCMC$set("public", name = "diagnose_sampler", value = diagnose_sampler) +CmdStanMCMC$set("public", name = "diagnostic_summary", value = diagnostic_summary) #' Extract inverse metric (mass matrix) after MCMC diff --git a/R/utils.R b/R/utils.R index ae223228c..e696b68cb 100644 --- a/R/utils.R +++ b/R/utils.R @@ -318,7 +318,7 @@ check_ebfmi <- function(post_warmup_sampler_diagnostics, threshold = 0.2) { invisible(unname(efbmi_per_chain)) } -# used in various places (e.g., fit$diagnose_sampler() and validate_sample_args()) +# used in various places (e.g., fit$diagnostic_summary() and validate_sample_args()) # to validate the selected diagnostics available_diagnostics <- function() { c("divergences", "treedepth", "ebfmi") diff --git a/man-roxygen/model-sample-args.R b/man-roxygen/model-sample-args.R index c7e2145c7..6704ca682 100644 --- a/man-roxygen/model-sample-args.R +++ b/man-roxygen/model-sample-args.R @@ -92,7 +92,7 @@ #' These diagnostics are also available after fitting. The #' [`$sampler_diagnostics()`][fit-method-sampler_diagnostics] method provides #' access the diagnostic values for each iteration and the -#' [`$diagnose_sampler()`][fit-method-diagnose_sampler] method provides +#' [`$diagnostic_summary()`][fit-method-diagnostic_summary] method provides #' summaries of the diagnostics and can regenerate the warning messages. #' #' Diagnostics like R-hat and effective sample size are not currently diff --git a/man/fit-method-diagnose_sampler.Rd b/man/fit-method-diagnostic_summary.Rd similarity index 70% rename from man/fit-method-diagnose_sampler.Rd rename to man/fit-method-diagnostic_summary.Rd index 2ce8d65d2..1aaf2ce09 100644 --- a/man/fit-method-diagnose_sampler.Rd +++ b/man/fit-method-diagnostic_summary.Rd @@ -1,11 +1,11 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/fit.R -\name{fit-method-diagnose_sampler} -\alias{fit-method-diagnose_sampler} -\alias{diagnose_sampler} +\name{fit-method-diagnostic_summary} +\alias{fit-method-diagnostic_summary} +\alias{diagnostic_summary} \title{Warnings and summaries of sampler diagnostics} \usage{ -diagnose_sampler( +diagnostic_summary( diagnostics = c("divergences", "treedepth", "ebfmi"), quiet = FALSE ) @@ -34,12 +34,20 @@ Warnings and summaries of sampler diagnostics. To instead get the underlying values of the sampler diagnostics for each iteration and chain use the \code{\link[=fit-method-sampler_diagnostics]{$sampler_diagnostics()}} method. + +This method is similar to \code{\link[=fit-method-cmdstan_summary]{$cmdstan_diagnose()}} +but everything is done in \R (rather than calling CmdStan utilities) and +the summaries are returned as a list (not just printed). + +Currently parameter-specific diagnostics like R-hat and effective sample +size are not handled by this method. Those diagnostics are provided via the +\code{\link[=fit-method-summary]{$summary()}} method. } \examples{ \dontrun{ fit <- cmdstanr_example("schools") -fit$diagnose_sampler() -fit$diagnose_sampler(quiet = TRUE) +fit$diagnostic_summary() +fit$diagnostic_summary(quiet = TRUE) } } diff --git a/man/model-method-sample.Rd b/man/model-method-sample.Rd index f35a59b71..1fc377387 100644 --- a/man/model-method-sample.Rd +++ b/man/model-method-sample.Rd @@ -255,7 +255,7 @@ and \code{"ebfmi"} (the default is to check all of them). These diagnostics are also available after fitting. The \code{\link[=fit-method-sampler_diagnostics]{$sampler_diagnostics()}} method provides access the diagnostic values for each iteration and the -\code{\link[=fit-method-diagnose_sampler]{$diagnose_sampler()}} method provides +\code{\link[=fit-method-diagnostic_summary]{$diagnostic_summary()}} method provides summaries of the diagnostics and can regenerate the warning messages. Diagnostics like R-hat and effective sample size are not currently diff --git a/man/model-method-sample_mpi.Rd b/man/model-method-sample_mpi.Rd index 45d9c2bfd..95e708818 100644 --- a/man/model-method-sample_mpi.Rd +++ b/man/model-method-sample_mpi.Rd @@ -234,7 +234,7 @@ and \code{"ebfmi"} (the default is to check all of them). These diagnostics are also available after fitting. The \code{\link[=fit-method-sampler_diagnostics]{$sampler_diagnostics()}} method provides access the diagnostic values for each iteration and the -\code{\link[=fit-method-diagnose_sampler]{$diagnose_sampler()}} method provides +\code{\link[=fit-method-diagnostic_summary]{$diagnostic_summary()}} method provides summaries of the diagnostics and can regenerate the warning messages. Diagnostics like R-hat and effective sample size are not currently diff --git a/tests/testthat/test-fit-mcmc.R b/tests/testthat/test-fit-mcmc.R index 284878f01..31afa8d76 100644 --- a/tests/testthat/test-fit-mcmc.R +++ b/tests/testthat/test-fit-mcmc.R @@ -341,10 +341,10 @@ test_that("draws() errors if invalid format", { ) }) -test_that("diagnose_sampler() works", { +test_that("diagnostic_summary() works", { fit <- suppressMessages(cmdstanr_example("schools")) expect_message( - diagnostics <- fit$diagnose_sampler(), + diagnostics <- fit$diagnostic_summary(), "transitions ended with a divergence" ) expect_equal( From f29a1dbe4072858cc8c17a0b6833473654055a08 Mon Sep 17 00:00:00 2001 From: jgabry Date: Thu, 4 Nov 2021 11:53:07 -0600 Subject: [PATCH 30/44] update news --- NEWS.md | 7 +++++-- R/fit.R | 15 +++++++++------ 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/NEWS.md b/NEWS.md index 57162e7b0..be4e3be93 100644 --- a/NEWS.md +++ b/NEWS.md @@ -51,8 +51,6 @@ decimal points to fix issue with parsing large numbers. (#538) * Added a convenience argument `user_header` to `$compile()` and `cmdstan_model()` that simplifies the use of an external .hpp file to compile with the model. -* Added E-BFMI checks that run automatically post sampling. (#500, @jsocolar) - * New method `$code()` for all fitted model objects that returns the Stan code associated with the fitted model. (#575) @@ -62,6 +60,11 @@ recompilation of Stan models. (#580) * New methods for `posterior::as_draws()` for CmdStanR fitted model objects. These are just wrappers around the `$draws()` method provided for convenience. (#532) +* Added E-BFMI checks that run automatically post sampling. (#500, @jsocolar) + +* New method `$diagnostic_summary()` that summarizes the sampler diagnostics +(divergences, treedepth, ebfmi) and can regenerate the related warning messages. (#205) + # cmdstanr 0.4.0 ### Bug fixes diff --git a/R/fit.R b/R/fit.R index 4d4fe0071..ddb32c0a2 100644 --- a/R/fit.R +++ b/R/fit.R @@ -824,6 +824,7 @@ CmdStanFit$set("public", name = "code", value = code) #' |**Method**|**Description**| #' |:----------|:---------------| #' [`$summary()`][fit-method-summary] | Run [`posterior::summarise_draws()`][posterior::draws_summary]. | +#' [`$diagnostic_summary()`][fit-method-diagnostic_summary] | Get summaries of sampler diagnostics and warning messages. | #' [`$cmdstan_summary()`][fit-method-cmdstan_summary] | Run and print CmdStan's `bin/stansummary`. | #' [`$cmdstan_diagnose()`][fit-method-cmdstan_summary] | Run and print CmdStan's `bin/diagnose`. | #' [`$loo()`][fit-method-loo] | Run [loo::loo.array()] for approximate LOO-CV | @@ -1049,7 +1050,9 @@ CmdStanMCMC$set("public", name = "loo", value = loo) #' @name fit-method-sampler_diagnostics #' @aliases sampler_diagnostics #' @description Extract the values of sampler diagnostics for each iteration and -#' chain of MCMC. +#' chain of MCMC. To instead get summaries of these diagnostics and associated +#' warning messages use the +#' [`$diagnostic_summary()`][fit-method-diagnostic_summary] method. #' #' @param inc_warmup (logical) Should warmup draws be included? Defaults to `FALSE`. #' @param format (string) The draws format to return. See @@ -1108,7 +1111,7 @@ sampler_diagnostics <- function(inc_warmup = FALSE, format = getOption("cmdstanr } CmdStanMCMC$set("public", name = "sampler_diagnostics", value = sampler_diagnostics) -#' Warnings and summaries of sampler diagnostics +#' Sampler diagnostic summaries and warnings #' #' @name fit-method-diagnostic_summary #' @aliases diagnostic_summary @@ -1135,8 +1138,8 @@ CmdStanMCMC$set("public", name = "sampler_diagnostics", value = sampler_diagnost #' #' @return A list with as many named elements as `diagnostics` selected. The #' possible elements and their values are: -#' * `"divergences"`: A vector of the number of divergences per chain. -#' * `"max_treedepths"`: A vector of the number of times `max_treedepth` was hit per chain. +#' * `"num_divergent"`: A vector of the number of divergences per chain. +#' * `"num_max_treedepth"`: A vector of the number of times `max_treedepth` was hit per chain. #' * `"ebfmi"`: A vector of E-BFMI values per chain. #' #' @seealso [`CmdStanMCMC`] and the @@ -1167,7 +1170,7 @@ diagnostic_summary <- function(diagnostics = c("divergences", "treedepth", "ebfm } else { divergences <- check_divergences(post_warmup_sampler_diagnostics) } - out[["divergences"]] <- divergences + out[["num_divergent"]] <- divergences } if ("treedepth" %in% diagnostics) { if (quiet) { @@ -1175,7 +1178,7 @@ diagnostic_summary <- function(diagnostics = c("divergences", "treedepth", "ebfm } else { max_treedepth_hit <- check_max_treedepth(post_warmup_sampler_diagnostics, self$metadata()) } - out[["max_treedepths"]] <- max_treedepth_hit + out[["num_max_treedepth"]] <- max_treedepth_hit } if ("ebfmi" %in% diagnostics) { if (quiet) { From be9cbdf0a0db973dcb1392776c31d11390bacaa9 Mon Sep 17 00:00:00 2001 From: jgabry Date: Thu, 4 Nov 2021 11:53:48 -0600 Subject: [PATCH 31/44] regenerate doc --- NEWS.md | 2 +- R/fit.R | 18 +++++++++--------- R/utils.R | 2 +- man-roxygen/model-sample-args.R | 2 +- man/CmdStanMCMC.Rd | 1 + ...mmary.Rd => fit-method-diagnose_sampler.Rd} | 18 +++++++++--------- man/fit-method-sampler_diagnostics.Rd | 4 +++- man/model-method-sample.Rd | 2 +- man/model-method-sample_mpi.Rd | 2 +- tests/testthat/test-fit-mcmc.R | 4 ++-- 10 files changed, 29 insertions(+), 26 deletions(-) rename man/{fit-method-diagnostic_summary.Rd => fit-method-diagnose_sampler.Rd} (79%) diff --git a/NEWS.md b/NEWS.md index be4e3be93..7369dda5c 100644 --- a/NEWS.md +++ b/NEWS.md @@ -62,7 +62,7 @@ These are just wrappers around the `$draws()` method provided for convenience. ( * Added E-BFMI checks that run automatically post sampling. (#500, @jsocolar) -* New method `$diagnostic_summary()` that summarizes the sampler diagnostics +* New method `$diagnose_sampler()` that summarizes the sampler diagnostics (divergences, treedepth, ebfmi) and can regenerate the related warning messages. (#205) # cmdstanr 0.4.0 diff --git a/R/fit.R b/R/fit.R index ddb32c0a2..7fcee1e57 100644 --- a/R/fit.R +++ b/R/fit.R @@ -824,7 +824,7 @@ CmdStanFit$set("public", name = "code", value = code) #' |**Method**|**Description**| #' |:----------|:---------------| #' [`$summary()`][fit-method-summary] | Run [`posterior::summarise_draws()`][posterior::draws_summary]. | -#' [`$diagnostic_summary()`][fit-method-diagnostic_summary] | Get summaries of sampler diagnostics and warning messages. | +#' [`$diagnose_sampler()`][fit-method-diagnose_sampler] | Get summaries of sampler diagnostics and warning messages. | #' [`$cmdstan_summary()`][fit-method-cmdstan_summary] | Run and print CmdStan's `bin/stansummary`. | #' [`$cmdstan_diagnose()`][fit-method-cmdstan_summary] | Run and print CmdStan's `bin/diagnose`. | #' [`$loo()`][fit-method-loo] | Run [loo::loo.array()] for approximate LOO-CV | @@ -865,7 +865,7 @@ CmdStanMCMC <- R6::R6Class( sampler_diagnostics = if (!fixed_param) diagnostics else "" ) if (!fixed_param) { - invisible(self$diagnostic_summary(diagnostics = diagnostics, quiet = FALSE)) + invisible(self$diagnose_sampler(diagnostics = diagnostics, quiet = FALSE)) } } } @@ -1052,7 +1052,7 @@ CmdStanMCMC$set("public", name = "loo", value = loo) #' @description Extract the values of sampler diagnostics for each iteration and #' chain of MCMC. To instead get summaries of these diagnostics and associated #' warning messages use the -#' [`$diagnostic_summary()`][fit-method-diagnostic_summary] method. +#' [`$diagnose_sampler()`][fit-method-diagnose_sampler] method. #' #' @param inc_warmup (logical) Should warmup draws be included? Defaults to `FALSE`. #' @param format (string) The draws format to return. See @@ -1113,8 +1113,8 @@ CmdStanMCMC$set("public", name = "sampler_diagnostics", value = sampler_diagnost #' Sampler diagnostic summaries and warnings #' -#' @name fit-method-diagnostic_summary -#' @aliases diagnostic_summary +#' @name fit-method-diagnose_sampler +#' @aliases diagnose_sampler #' @description Warnings and summaries of sampler diagnostics. To instead get #' the underlying values of the sampler diagnostics for each iteration and #' chain use the [`$sampler_diagnostics()`][fit-method-sampler_diagnostics] @@ -1148,11 +1148,11 @@ CmdStanMCMC$set("public", name = "sampler_diagnostics", value = sampler_diagnost #' @examples #' \dontrun{ #' fit <- cmdstanr_example("schools") -#' fit$diagnostic_summary() -#' fit$diagnostic_summary(quiet = TRUE) +#' fit$diagnose_sampler() +#' fit$diagnose_sampler(quiet = TRUE) #' } #' -diagnostic_summary <- function(diagnostics = c("divergences", "treedepth", "ebfmi"), quiet = FALSE) { +diagnose_sampler <- function(diagnostics = c("divergences", "treedepth", "ebfmi"), quiet = FALSE) { if (is.null(private$sampler_diagnostics_) && !length(self$output_files(include_failed = FALSE))) { stop("No chains finished successfully. Unable to retrieve the sampler diagnostics.", call. = FALSE) @@ -1190,7 +1190,7 @@ diagnostic_summary <- function(diagnostics = c("divergences", "treedepth", "ebfm } out } -CmdStanMCMC$set("public", name = "diagnostic_summary", value = diagnostic_summary) +CmdStanMCMC$set("public", name = "diagnose_sampler", value = diagnose_sampler) #' Extract inverse metric (mass matrix) after MCMC diff --git a/R/utils.R b/R/utils.R index e696b68cb..ae223228c 100644 --- a/R/utils.R +++ b/R/utils.R @@ -318,7 +318,7 @@ check_ebfmi <- function(post_warmup_sampler_diagnostics, threshold = 0.2) { invisible(unname(efbmi_per_chain)) } -# used in various places (e.g., fit$diagnostic_summary() and validate_sample_args()) +# used in various places (e.g., fit$diagnose_sampler() and validate_sample_args()) # to validate the selected diagnostics available_diagnostics <- function() { c("divergences", "treedepth", "ebfmi") diff --git a/man-roxygen/model-sample-args.R b/man-roxygen/model-sample-args.R index 6704ca682..c7e2145c7 100644 --- a/man-roxygen/model-sample-args.R +++ b/man-roxygen/model-sample-args.R @@ -92,7 +92,7 @@ #' These diagnostics are also available after fitting. The #' [`$sampler_diagnostics()`][fit-method-sampler_diagnostics] method provides #' access the diagnostic values for each iteration and the -#' [`$diagnostic_summary()`][fit-method-diagnostic_summary] method provides +#' [`$diagnose_sampler()`][fit-method-diagnose_sampler] method provides #' summaries of the diagnostics and can regenerate the warning messages. #' #' Diagnostics like R-hat and effective sample size are not currently diff --git a/man/CmdStanMCMC.Rd b/man/CmdStanMCMC.Rd index c99734dcc..781c29db9 100644 --- a/man/CmdStanMCMC.Rd +++ b/man/CmdStanMCMC.Rd @@ -29,6 +29,7 @@ methods, all of which have their own (linked) documentation pages. \subsection{Summarize inferences and diagnostics}{\tabular{ll}{ \strong{Method} \tab \strong{Description} \cr \code{\link[=fit-method-summary]{$summary()}} \tab Run \code{\link[posterior:draws_summary]{posterior::summarise_draws()}}. \cr + \code{\link[=fit-method-diagnose_sampler]{$diagnose_sampler()}} \tab Get summaries of sampler diagnostics and warning messages. \cr \code{\link[=fit-method-cmdstan_summary]{$cmdstan_summary()}} \tab Run and print CmdStan's \code{bin/stansummary}. \cr \code{\link[=fit-method-cmdstan_summary]{$cmdstan_diagnose()}} \tab Run and print CmdStan's \code{bin/diagnose}. \cr \code{\link[=fit-method-loo]{$loo()}} \tab Run \code{\link[loo:loo]{loo::loo.array()}} for approximate LOO-CV \cr diff --git a/man/fit-method-diagnostic_summary.Rd b/man/fit-method-diagnose_sampler.Rd similarity index 79% rename from man/fit-method-diagnostic_summary.Rd rename to man/fit-method-diagnose_sampler.Rd index 1aaf2ce09..265410691 100644 --- a/man/fit-method-diagnostic_summary.Rd +++ b/man/fit-method-diagnose_sampler.Rd @@ -1,11 +1,11 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/fit.R -\name{fit-method-diagnostic_summary} -\alias{fit-method-diagnostic_summary} -\alias{diagnostic_summary} -\title{Warnings and summaries of sampler diagnostics} +\name{fit-method-diagnose_sampler} +\alias{fit-method-diagnose_sampler} +\alias{diagnose_sampler} +\title{Sampler diagnostic summaries and warnings} \usage{ -diagnostic_summary( +diagnose_sampler( diagnostics = c("divergences", "treedepth", "ebfmi"), quiet = FALSE ) @@ -24,8 +24,8 @@ displayed.} A list with as many named elements as \code{diagnostics} selected. The possible elements and their values are: \itemize{ -\item \code{"divergences"}: A vector of the number of divergences per chain. -\item \code{"max_treedepths"}: A vector of the number of times \code{max_treedepth} was hit per chain. +\item \code{"num_divergent"}: A vector of the number of divergences per chain. +\item \code{"num_max_treedepth"}: A vector of the number of times \code{max_treedepth} was hit per chain. \item \code{"ebfmi"}: A vector of E-BFMI values per chain. } } @@ -46,8 +46,8 @@ size are not handled by this method. Those diagnostics are provided via the \examples{ \dontrun{ fit <- cmdstanr_example("schools") -fit$diagnostic_summary() -fit$diagnostic_summary(quiet = TRUE) +fit$diagnose_sampler() +fit$diagnose_sampler(quiet = TRUE) } } diff --git a/man/fit-method-sampler_diagnostics.Rd b/man/fit-method-sampler_diagnostics.Rd index aef705747..5b48f4449 100644 --- a/man/fit-method-sampler_diagnostics.Rd +++ b/man/fit-method-sampler_diagnostics.Rd @@ -25,7 +25,9 @@ variable). The variables for Stan's default MCMC algorithm are } \description{ Extract the values of sampler diagnostics for each iteration and -chain of MCMC. +chain of MCMC. To instead get summaries of these diagnostics and associated +warning messages use the +\code{\link[=fit-method-diagnose_sampler]{$diagnose_sampler()}} method. } \examples{ \dontrun{ diff --git a/man/model-method-sample.Rd b/man/model-method-sample.Rd index 1fc377387..f35a59b71 100644 --- a/man/model-method-sample.Rd +++ b/man/model-method-sample.Rd @@ -255,7 +255,7 @@ and \code{"ebfmi"} (the default is to check all of them). These diagnostics are also available after fitting. The \code{\link[=fit-method-sampler_diagnostics]{$sampler_diagnostics()}} method provides access the diagnostic values for each iteration and the -\code{\link[=fit-method-diagnostic_summary]{$diagnostic_summary()}} method provides +\code{\link[=fit-method-diagnose_sampler]{$diagnose_sampler()}} method provides summaries of the diagnostics and can regenerate the warning messages. Diagnostics like R-hat and effective sample size are not currently diff --git a/man/model-method-sample_mpi.Rd b/man/model-method-sample_mpi.Rd index 95e708818..45d9c2bfd 100644 --- a/man/model-method-sample_mpi.Rd +++ b/man/model-method-sample_mpi.Rd @@ -234,7 +234,7 @@ and \code{"ebfmi"} (the default is to check all of them). These diagnostics are also available after fitting. The \code{\link[=fit-method-sampler_diagnostics]{$sampler_diagnostics()}} method provides access the diagnostic values for each iteration and the -\code{\link[=fit-method-diagnostic_summary]{$diagnostic_summary()}} method provides +\code{\link[=fit-method-diagnose_sampler]{$diagnose_sampler()}} method provides summaries of the diagnostics and can regenerate the warning messages. Diagnostics like R-hat and effective sample size are not currently diff --git a/tests/testthat/test-fit-mcmc.R b/tests/testthat/test-fit-mcmc.R index 31afa8d76..284878f01 100644 --- a/tests/testthat/test-fit-mcmc.R +++ b/tests/testthat/test-fit-mcmc.R @@ -341,10 +341,10 @@ test_that("draws() errors if invalid format", { ) }) -test_that("diagnostic_summary() works", { +test_that("diagnose_sampler() works", { fit <- suppressMessages(cmdstanr_example("schools")) expect_message( - diagnostics <- fit$diagnostic_summary(), + diagnostics <- fit$diagnose_sampler(), "transitions ended with a divergence" ) expect_equal( From 889b4389c54efddb5290603f693052f509edf2ef Mon Sep 17 00:00:00 2001 From: jgabry Date: Thu, 4 Nov 2021 11:58:35 -0600 Subject: [PATCH 32/44] doc edits --- R/fit.R | 7 +++---- man/fit-method-diagnose_sampler.Rd | 7 +++---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/R/fit.R b/R/fit.R index 7fcee1e57..4c96b7e8e 100644 --- a/R/fit.R +++ b/R/fit.R @@ -1131,10 +1131,9 @@ CmdStanMCMC$set("public", name = "sampler_diagnostics", value = sampler_diagnost #' @param diagnostics (character vector) One or more diagnostics to check. The #' currently supported diagnostics are `"divergences`, `"treedepth"`, and #' `"ebfmi`. -#' @param quiet (logical) Should messages about the diagnostics be displayed? -#' The values of the diagnostics are always returned but if `quiet = FALSE` -#' (the default) the warning messages about the diagnostics are also -#' displayed. +#' @param quiet (logical) Should warning messages about the diagnostics be +#' suppressed? The default is `FALSE`, in which case warning messages are +#' printed in addition to returning the values of the diagnostics. #' #' @return A list with as many named elements as `diagnostics` selected. The #' possible elements and their values are: diff --git a/man/fit-method-diagnose_sampler.Rd b/man/fit-method-diagnose_sampler.Rd index 265410691..b66bee56c 100644 --- a/man/fit-method-diagnose_sampler.Rd +++ b/man/fit-method-diagnose_sampler.Rd @@ -15,10 +15,9 @@ diagnose_sampler( currently supported diagnostics are \verb{"divergences}, \code{"treedepth"}, and \verb{"ebfmi}.} -\item{quiet}{(logical) Should messages about the diagnostics be displayed? -The values of the diagnostics are always returned but if \code{quiet = FALSE} -(the default) the warning messages about the diagnostics are also -displayed.} +\item{quiet}{(logical) Should warning messages about the diagnostics be +suppressed? The default is \code{FALSE}, in which case warning messages are +printed in addition to returning the values of the diagnostics.} } \value{ A list with as many named elements as \code{diagnostics} selected. The From b718ad647025ee5da8752c4f6f764b75a45435cc Mon Sep 17 00:00:00 2001 From: jgabry Date: Thu, 4 Nov 2021 14:41:12 -0600 Subject: [PATCH 33/44] fixes to get tests passing --- NEWS.md | 3 +++ R/fit.R | 21 +++++++++++---------- R/model.R | 2 +- R/utils.R | 19 +++++++++++++++++++ man/fit-method-diagnose_sampler.Rd | 7 ++----- tests/testthat/test-fit-mcmc.R | 16 +++++++++++++--- 6 files changed, 49 insertions(+), 19 deletions(-) diff --git a/NEWS.md b/NEWS.md index 7369dda5c..aea05aad5 100644 --- a/NEWS.md +++ b/NEWS.md @@ -65,6 +65,9 @@ These are just wrappers around the `$draws()` method provided for convenience. ( * New method `$diagnose_sampler()` that summarizes the sampler diagnostics (divergences, treedepth, ebfmi) and can regenerate the related warning messages. (#205) +* New `diagnostics` argument for the `$sample()` method to specify which +diagnostics are checked after sampling. Replaces `validate_csv` argument. (#205) + # cmdstanr 0.4.0 ### Bug fixes diff --git a/R/fit.R b/R/fit.R index 4c96b7e8e..8fcf8a704 100644 --- a/R/fit.R +++ b/R/fit.R @@ -857,13 +857,17 @@ CmdStanMCMC <- R6::R6Class( warning("No chains finished successfully. Unable to retrieve the fit.", call. = FALSE) } else { + # throw diagnostic warnings if user asked for them and if not fixed_param if (!is.null(self$runset$args$method_args$diagnostics)) { - diagnostics <- self$runset$method_args$diagnostics + diagnostics <- self$runset$args$method_args$diagnostics fixed_param <- runset$args$method_args$fixed_param - private$read_csv_( - variables = "", - sampler_diagnostics = if (!fixed_param) diagnostics else "" - ) + if (!fixed_param) { + # convert user friendly names to actual diagnostic names (e.g. divergences --> divergent__) + diagnostics_to_read <- convert_diagnostic_names(diagnostics) + } else { + diagnostics_to_read <- "" + } + private$read_csv_(variables = "", sampler_diagnostics = diagnostics_to_read) if (!fixed_param) { invisible(self$diagnose_sampler(diagnostics = diagnostics, quiet = FALSE)) } @@ -1120,13 +1124,10 @@ CmdStanMCMC$set("public", name = "sampler_diagnostics", value = sampler_diagnost #' chain use the [`$sampler_diagnostics()`][fit-method-sampler_diagnostics] #' method. #' -#' This method is similar to [`$cmdstan_diagnose()`][fit-method-cmdstan_summary] -#' but everything is done in \R (rather than calling CmdStan utilities) and -#' the summaries are returned as a list (not just printed). -#' #' Currently parameter-specific diagnostics like R-hat and effective sample #' size are not handled by this method. Those diagnostics are provided via the -#' [`$summary()`][fit-method-summary] method. +#' [`$summary()`][fit-method-summary] method (using +#' [posterior::summarize_draws()]). #' #' @param diagnostics (character vector) One or more diagnostics to check. The #' currently supported diagnostics are `"divergences`, `"treedepth"`, and diff --git a/R/model.R b/R/model.R index c586d8230..97bb3fc00 100644 --- a/R/model.R +++ b/R/model.R @@ -1021,7 +1021,7 @@ sample_mpi <- function(data = NULL, validate_csv = TRUE) { if (!is.null(validate_csv)) { - warning("'validate_csv' is deprecated. Please set 'diagnostics=NULL' instead.") + warning("'validate_csv' is deprecated. Please use 'diagnostics' instead.") if (is.logical(validate_csv)) { if (validate_csv) { diagnostics <- c("divergences", "treedepth", "ebfmi") diff --git a/R/utils.R b/R/utils.R index ae223228c..3610f4c05 100644 --- a/R/utils.R +++ b/R/utils.R @@ -324,6 +324,25 @@ available_diagnostics <- function() { c("divergences", "treedepth", "ebfmi") } +# in some places we need to convert user friendly names +# to the names used in the sampler diagnostics files: +# * ebfmi --> energy__ +# * divergences --> divergent__ +# * treedepth --> treedepth__ +convert_diagnostic_names <- function(diagnostics) { + diagnostic_names <- c() + if ("divergences" %in% diagnostics) { + diagnostic_names <- c(diagnostic_names, "divergent__") + } + if ("treedepth" %in% diagnostics) { + diagnostic_names <- c(diagnostic_names, "treedepth__") + } + if ("ebfmi" %in% diagnostics) { + diagnostic_names <- c(diagnostic_names, "energy__") + } + diagnostic_names +} + # draws formatting -------------------------------------------------------- as_draws_format_fun <- function(draws_format) { diff --git a/man/fit-method-diagnose_sampler.Rd b/man/fit-method-diagnose_sampler.Rd index b66bee56c..f8e265d23 100644 --- a/man/fit-method-diagnose_sampler.Rd +++ b/man/fit-method-diagnose_sampler.Rd @@ -34,13 +34,10 @@ the underlying values of the sampler diagnostics for each iteration and chain use the \code{\link[=fit-method-sampler_diagnostics]{$sampler_diagnostics()}} method. -This method is similar to \code{\link[=fit-method-cmdstan_summary]{$cmdstan_diagnose()}} -but everything is done in \R (rather than calling CmdStan utilities) and -the summaries are returned as a list (not just printed). - Currently parameter-specific diagnostics like R-hat and effective sample size are not handled by this method. Those diagnostics are provided via the -\code{\link[=fit-method-summary]{$summary()}} method. +\code{\link[=fit-method-summary]{$summary()}} method (using +\code{\link[posterior:draws_summary]{posterior::summarize_draws()}}). } \examples{ \dontrun{ diff --git a/tests/testthat/test-fit-mcmc.R b/tests/testthat/test-fit-mcmc.R index 284878f01..7d27aefdd 100644 --- a/tests/testthat/test-fit-mcmc.R +++ b/tests/testthat/test-fit-mcmc.R @@ -41,7 +41,7 @@ test_that("draws() stops for unkown variables", { test_that("draws() works when gradually adding variables", { skip_on_cran() fit <- testing_fit("logistic", method = "sample", refresh = 0, - validate_csv = TRUE, save_warmup = TRUE) + save_warmup = TRUE) draws_lp__ <- fit$draws(variables = c("lp__"), inc_warmup = TRUE) sampler_diagnostics <- fit$sampler_diagnostics(inc_warmup = TRUE) @@ -342,21 +342,31 @@ test_that("draws() errors if invalid format", { }) test_that("diagnose_sampler() works", { + # will have divergences fit <- suppressMessages(cmdstanr_example("schools")) + expect_message( diagnostics <- fit$diagnose_sampler(), "transitions ended with a divergence" ) expect_equal( - diagnostics$divergences, + diagnostics$num_divergent, suppressMessages(check_divergences(fit$sampler_diagnostics())) ) expect_equal( - diagnostics$max_treedepths, + diagnostics$num_max_treedepth, suppressMessages(check_max_treedepth(fit$sampler_diagnostics(), fit$metadata())) ) expect_equal( diagnostics$ebfmi, suppressMessages(check_ebfmi(fit$sampler_diagnostics())) ) + + # ebfmi not defined if iter < 3 + fit <- suppressWarnings(suppressMessages(cmdstanr_example("schools", iter_sampling = 2))) + expect_warning( + diagnostics <- fit$diagnose_sampler(), + "E-BFMI not computed" + ) + expect_equal(diagnostics$ebfmi, NA) }) From 6d8c06e91d0395a1c3601e55644cda6b6665550c Mon Sep 17 00:00:00 2001 From: jgabry Date: Thu, 4 Nov 2021 14:58:54 -0600 Subject: [PATCH 34/44] make sure treedepth warnings are tested too --- tests/testthat/resources/stan/.gitignore | 15 +++++++++++---- tests/testthat/test-fit-mcmc.R | 10 +++++++--- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/tests/testthat/resources/stan/.gitignore b/tests/testthat/resources/stan/.gitignore index 9d8ec1f05..c55f933eb 100644 --- a/tests/testthat/resources/stan/.gitignore +++ b/tests/testthat/resources/stan/.gitignore @@ -1,9 +1,16 @@ /bernoulli -/logistic -/chain_fails -/init_warnings -/bernoulli_threads +/bernoulli_external /bernoulli_fp /bernoulli_include +/bernoulli_log_lik /bernoulli_ppc +/bernoulli_threads +/chain_fails +/divide_real_by_two +/fail /info_message +/init_warnings +/logistic +/logistic_profiling +/schools + diff --git a/tests/testthat/test-fit-mcmc.R b/tests/testthat/test-fit-mcmc.R index 7d27aefdd..ea37c7c31 100644 --- a/tests/testthat/test-fit-mcmc.R +++ b/tests/testthat/test-fit-mcmc.R @@ -342,8 +342,8 @@ test_that("draws() errors if invalid format", { }) test_that("diagnose_sampler() works", { - # will have divergences - fit <- suppressMessages(cmdstanr_example("schools")) + # will have divergences and treedepth problems + fit <- suppressMessages(testing_fit("schools", max_treedepth = 3, seed = 123)) expect_message( diagnostics <- fit$diagnose_sampler(), @@ -353,6 +353,10 @@ test_that("diagnose_sampler() works", { diagnostics$num_divergent, suppressMessages(check_divergences(fit$sampler_diagnostics())) ) + expect_message( + diagnostics <- fit$diagnose_sampler(), + "transitions hit the maximum treedepth limit of 3" + ) expect_equal( diagnostics$num_max_treedepth, suppressMessages(check_max_treedepth(fit$sampler_diagnostics(), fit$metadata())) @@ -363,7 +367,7 @@ test_that("diagnose_sampler() works", { ) # ebfmi not defined if iter < 3 - fit <- suppressWarnings(suppressMessages(cmdstanr_example("schools", iter_sampling = 2))) + fit <- suppressWarnings(suppressMessages(testing_fit("schools", iter_sampling = 2))) expect_warning( diagnostics <- fit$diagnose_sampler(), "E-BFMI not computed" From 267af8fda2c59c6165982daaac625045349f3f6b Mon Sep 17 00:00:00 2001 From: jgabry Date: Thu, 4 Nov 2021 15:14:56 -0600 Subject: [PATCH 35/44] Update DESCRIPTION --- DESCRIPTION | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index a5dbc07ce..e242a7aea 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -12,7 +12,8 @@ Authors@R: person(given = "Mikhail", family = "Popov", role = "ctb"), person(given = "Mike", family = "Lawrence", role = "ctb"), person(given = c("William", "Michael"), family = "Landau", role = "ctb", - email = "will.landau@gmail.com", comment = c(ORCID = "0000-0003-1878-3253"))) + email = "will.landau@gmail.com", comment = c(ORCID = "0000-0003-1878-3253")), + person(given = "Jacob", family = "Socolar", role = "ctb")) Description: A lightweight interface to 'Stan' . The 'CmdStanR' interface is an alternative to 'RStan' that calls the command line interface for compilation and running algorithms instead of interfacing From 802819b1361fdafe15052e31c19483935e22ad0a Mon Sep 17 00:00:00 2001 From: jgabry Date: Fri, 5 Nov 2021 11:53:44 -0600 Subject: [PATCH 36/44] demonstrate new method in vignette --- vignettes/cmdstanr.Rmd | 72 ++++++++++++++++++++++++++++++++---------- 1 file changed, 55 insertions(+), 17 deletions(-) diff --git a/vignettes/cmdstanr.Rmd b/vignettes/cmdstanr.Rmd index 1814a6161..4bf8b3222 100644 --- a/vignettes/cmdstanr.Rmd +++ b/vignettes/cmdstanr.Rmd @@ -4,7 +4,7 @@ author: "Jonah Gabry and Rok Češnovar" output: rmarkdown::html_vignette: toc: true - toc_depth: 4 + toc_depth: 3 params: EVAL: !r identical(Sys.getenv("NOT_CRAN"), "true") vignette: > @@ -160,7 +160,7 @@ fit <- mod$sample( seed = 123, chains = 4, parallel_chains = 4, - refresh = 500 + refresh = 500 # print update every 500 iters ) ``` @@ -178,18 +178,28 @@ the most important methods. For a full list, follow this link to the ### Posterior summary statistics +#### Summaries from the posterior package + The [`$summary()`](https://mc-stan.org/cmdstanr/reference/fit-method-summary.html) -method calls `summarise_draws()` from the **posterior** package: +method calls `summarise_draws()` from the **posterior** package. The +first argument specifies the variables to summarize and any arguments +after that are passed on to `posterior::summarise_draws()` to specify +which summaries to compute, whether to use multiple cores, etc. ```{r summary} fit$summary() -fit$summary("theta", "mean", "sd") +fit$summary(variables = c("theta", "lp__"), "mean", "sd") # use a formula to summarize arbitrary functions, e.g. Pr(theta <= 0.5) fit$summary("theta", pr_lt_half = ~ mean(. <= 0.5)) ``` +#### CmdStan's stansummary utility + +CmdStan itself provides a `stansummary` utility that can be called using the +`$cmdstan_summary()` method. This method will print summaries but won't return +anything. ### Posterior draws @@ -238,38 +248,66 @@ mcmc_hist(fit$draws("theta")) ### Sampler diagnostics +#### Extracting diagnostic values for each iteration and chain + The [`$sampler_diagnostics()`](https://mc-stan.org/cmdstanr/reference/fit-method-sampler_diagnostics.html) method extracts the values of the sampler parameters (`treedepth__`, -`divergent__`, etc.) as a 3-D array (iteration x chain x variable). +`divergent__`, etc.) in formats supported by the **posterior** package. The +default is as a 3-D array (iteration x chain x variable). ```{r sampler_diagnostics} # this is a draws_array object from the posterior package str(fit$sampler_diagnostics()) -# convert to matrix or data frame using posterior package -diagnostics_df <- as_draws_df(fit$sampler_diagnostics()) -print(diagnostics_df) +# this is a draws_df object from the posterior package +str(fit$sampler_diagnostics(format = "df")) +``` + +#### Sampler diagnostic warnings and summaries + +The `$diagnose_sampler()` method will display any sampler diagnostic warnings and return a summary of diagnostics for each chain. + +```{r diagnose_sampler} +fit$diagnose_sampler() ``` -### CmdStan utilities +We see the number of divergences for each of the four chains, the number +of times the maximum treedepth was hit for each chain, and the E-BFMI +for each chain. + +In this case there were no warnings, so in order to demonstrate the warning +messages we'll use one of the CmdStanR example models that suffers from +divergences. + +```{r fit-with-warnings, results='hold'} +fit_with_warning <- cmdstanr_example("schools") +``` +After fitting there is a warning about divergences. We can also regenerate this warning message later using `fit$diagnose_sampler()`. -The [`$cmdstan_diagnose()`](https://mc-stan.org/cmdstanr/reference/fit-method-cmdstan_summary.html) -and [`$cmdstan_summary()`](https://mc-stan.org/cmdstanr/reference/fit-method-cmdstan_summary.html) -methods call CmdStan's `diagnose` and `stansummary` utilities. +```{r diagnose_sampler-with-warnings} +diagnostics <- fit_with_warning$diagnose_sampler() +print(diagnostics) -```{r summary-and-diagnose} -fit$cmdstan_diagnose() -fit$cmdstan_summary() +# number of divergences reported in warning is the sum of the per chain values +sum(diagnostics$num_divergent) ``` +#### CmdStan's diagnose utility + +CmdStan itself provides a `diagnose` utility that can be called using +the `$cmdstan_diagnose()` method. This method will print warnings but won't return anything. + ### Create a `stanfit` object -If you have RStan installed then it is also possible to create a `stanfit` +If you have RStan installed then it is also possible to create a `stanfit` object from the csv output files written by CmdStan. This can be done by using `rstan::read_stan_csv()` in combination with the `$output_files()` method of the -`CmdStanMCMC` object. +`CmdStanMCMC` object. This is only needed if you want to fit a model with +CmdStanR but already have a lot of post-processing code that assumes a `stanfit` +object. Otherwise we recommend using the post-processing functionality provided +by CmdStanR itself. ```{r stanfit, eval=FALSE} stanfit <- rstan::read_stan_csv(fit$output_files()) From ce64d9807508106e77e38827a08c8148281d7478 Mon Sep 17 00:00:00 2001 From: jgabry Date: Tue, 9 Nov 2021 09:28:15 -0700 Subject: [PATCH 37/44] diagnose_sampler -> diagnostic_summary --- NEWS.md | 2 +- R/args.R | 2 +- R/fit.R | 30 ++++++++----------- R/utils.R | 8 ++--- man-roxygen/model-sample-args.R | 2 +- man/CmdStanMCMC.Rd | 2 +- ...er.Rd => fit-method-diagnostic_summary.Rd} | 16 +++++----- man/fit-method-sampler_diagnostics.Rd | 2 +- man/model-method-sample.Rd | 2 +- man/model-method-sample_mpi.Rd | 2 +- tests/testthat/test-fit-mcmc.R | 8 ++--- vignettes/cmdstanr.Rmd | 12 ++++---- 12 files changed, 42 insertions(+), 46 deletions(-) rename man/{fit-method-diagnose_sampler.Rd => fit-method-diagnostic_summary.Rd} (82%) diff --git a/NEWS.md b/NEWS.md index 7a46d9cf7..4d1390dc5 100644 --- a/NEWS.md +++ b/NEWS.md @@ -62,7 +62,7 @@ These are just wrappers around the `$draws()` method provided for convenience. ( * Added E-BFMI checks that run automatically post sampling. (#500, @jsocolar) -* New method `$diagnose_sampler()` that summarizes the sampler diagnostics +* New method `$diagnostic_summary()` that summarizes the sampler diagnostics (divergences, treedepth, ebfmi) and can regenerate the related warning messages. (#205) * New `diagnostics` argument for the `$sample()` method to specify which diff --git a/R/args.R b/R/args.R index a51255dc6..91627f653 100644 --- a/R/args.R +++ b/R/args.R @@ -642,7 +642,7 @@ validate_sample_args <- function(self, num_procs) { checkmate::assert_character(self$diagnostics, null.ok = TRUE, any.missing = FALSE) if (!is.null(self$diagnostics)) { - checkmate::assert_subset(self$diagnostics, empty.ok = FALSE, choices = available_diagnostics()) + checkmate::assert_subset(self$diagnostics, empty.ok = FALSE, choices = available_hmc_diagnostics()) } invisible(TRUE) diff --git a/R/fit.R b/R/fit.R index 8fcf8a704..d21e8fa31 100644 --- a/R/fit.R +++ b/R/fit.R @@ -824,7 +824,7 @@ CmdStanFit$set("public", name = "code", value = code) #' |**Method**|**Description**| #' |:----------|:---------------| #' [`$summary()`][fit-method-summary] | Run [`posterior::summarise_draws()`][posterior::draws_summary]. | -#' [`$diagnose_sampler()`][fit-method-diagnose_sampler] | Get summaries of sampler diagnostics and warning messages. | +#' [`$diagnostic_summary()`][fit-method-diagnostic_summary] | Get summaries of sampler diagnostics and warning messages. | #' [`$cmdstan_summary()`][fit-method-cmdstan_summary] | Run and print CmdStan's `bin/stansummary`. | #' [`$cmdstan_diagnose()`][fit-method-cmdstan_summary] | Run and print CmdStan's `bin/diagnose`. | #' [`$loo()`][fit-method-loo] | Run [loo::loo.array()] for approximate LOO-CV | @@ -863,13 +863,13 @@ CmdStanMCMC <- R6::R6Class( fixed_param <- runset$args$method_args$fixed_param if (!fixed_param) { # convert user friendly names to actual diagnostic names (e.g. divergences --> divergent__) - diagnostics_to_read <- convert_diagnostic_names(diagnostics) + diagnostics_to_read <- convert_hmc_diagnostic_names(diagnostics) } else { diagnostics_to_read <- "" } private$read_csv_(variables = "", sampler_diagnostics = diagnostics_to_read) if (!fixed_param) { - invisible(self$diagnose_sampler(diagnostics = diagnostics, quiet = FALSE)) + invisible(self$diagnostic_summary(diagnostics = diagnostics, quiet = FALSE)) } } } @@ -1056,7 +1056,7 @@ CmdStanMCMC$set("public", name = "loo", value = loo) #' @description Extract the values of sampler diagnostics for each iteration and #' chain of MCMC. To instead get summaries of these diagnostics and associated #' warning messages use the -#' [`$diagnose_sampler()`][fit-method-diagnose_sampler] method. +#' [`$diagnostic_summary()`][fit-method-diagnostic_summary] method. #' #' @param inc_warmup (logical) Should warmup draws be included? Defaults to `FALSE`. #' @param format (string) The draws format to return. See @@ -1117,16 +1117,16 @@ CmdStanMCMC$set("public", name = "sampler_diagnostics", value = sampler_diagnost #' Sampler diagnostic summaries and warnings #' -#' @name fit-method-diagnose_sampler -#' @aliases diagnose_sampler +#' @name fit-method-diagnostic_summary +#' @aliases diagnostic_summary #' @description Warnings and summaries of sampler diagnostics. To instead get #' the underlying values of the sampler diagnostics for each iteration and #' chain use the [`$sampler_diagnostics()`][fit-method-sampler_diagnostics] #' method. #' #' Currently parameter-specific diagnostics like R-hat and effective sample -#' size are not handled by this method. Those diagnostics are provided via the -#' [`$summary()`][fit-method-summary] method (using +#' size are _not_ handled by this method. Those diagnostics are provided via +#' the [`$summary()`][fit-method-summary] method (using #' [posterior::summarize_draws()]). #' #' @param diagnostics (character vector) One or more diagnostics to check. The @@ -1148,18 +1148,14 @@ CmdStanMCMC$set("public", name = "sampler_diagnostics", value = sampler_diagnost #' @examples #' \dontrun{ #' fit <- cmdstanr_example("schools") -#' fit$diagnose_sampler() -#' fit$diagnose_sampler(quiet = TRUE) +#' fit$diagnostic_summary() +#' fit$diagnostic_summary(quiet = TRUE) #' } #' -diagnose_sampler <- function(diagnostics = c("divergences", "treedepth", "ebfmi"), quiet = FALSE) { - if (is.null(private$sampler_diagnostics_) && - !length(self$output_files(include_failed = FALSE))) { - stop("No chains finished successfully. Unable to retrieve the sampler diagnostics.", call. = FALSE) - } +diagnostic_summary <- function(diagnostics = c("divergences", "treedepth", "ebfmi"), quiet = FALSE) { diagnostics <- match.arg( diagnostics, - choices = available_diagnostics(), + choices = available_hmc_diagnostics(), several.ok = TRUE ) post_warmup_sampler_diagnostics <- self$sampler_diagnostics(inc_warmup = FALSE) @@ -1190,7 +1186,7 @@ diagnose_sampler <- function(diagnostics = c("divergences", "treedepth", "ebfmi" } out } -CmdStanMCMC$set("public", name = "diagnose_sampler", value = diagnose_sampler) +CmdStanMCMC$set("public", name = "diagnostic_summary", value = diagnostic_summary) #' Extract inverse metric (mass matrix) after MCMC diff --git a/R/utils.R b/R/utils.R index d205d738d..b41681623 100644 --- a/R/utils.R +++ b/R/utils.R @@ -237,7 +237,7 @@ set_num_threads <- function(num_threads) { } -# convergence checks ------------------------------------------------------ +# hmc diagnostics ------------------------------------------------------ check_divergences <- function(post_warmup_sampler_diagnostics) { num_divergences_per_chain <- NULL if (!is.null(post_warmup_sampler_diagnostics)) { @@ -318,9 +318,9 @@ check_ebfmi <- function(post_warmup_sampler_diagnostics, threshold = 0.2) { invisible(unname(efbmi_per_chain)) } -# used in various places (e.g., fit$diagnose_sampler() and validate_sample_args()) +# used in various places (e.g., fit$diagnostic_summary() and validate_sample_args()) # to validate the selected diagnostics -available_diagnostics <- function() { +available_hmc_diagnostics <- function() { c("divergences", "treedepth", "ebfmi") } @@ -329,7 +329,7 @@ available_diagnostics <- function() { # * ebfmi --> energy__ # * divergences --> divergent__ # * treedepth --> treedepth__ -convert_diagnostic_names <- function(diagnostics) { +convert_hmc_diagnostic_names <- function(diagnostics) { diagnostic_names <- c() if ("divergences" %in% diagnostics) { diagnostic_names <- c(diagnostic_names, "divergent__") diff --git a/man-roxygen/model-sample-args.R b/man-roxygen/model-sample-args.R index c7e2145c7..6704ca682 100644 --- a/man-roxygen/model-sample-args.R +++ b/man-roxygen/model-sample-args.R @@ -92,7 +92,7 @@ #' These diagnostics are also available after fitting. The #' [`$sampler_diagnostics()`][fit-method-sampler_diagnostics] method provides #' access the diagnostic values for each iteration and the -#' [`$diagnose_sampler()`][fit-method-diagnose_sampler] method provides +#' [`$diagnostic_summary()`][fit-method-diagnostic_summary] method provides #' summaries of the diagnostics and can regenerate the warning messages. #' #' Diagnostics like R-hat and effective sample size are not currently diff --git a/man/CmdStanMCMC.Rd b/man/CmdStanMCMC.Rd index 781c29db9..950beffca 100644 --- a/man/CmdStanMCMC.Rd +++ b/man/CmdStanMCMC.Rd @@ -29,7 +29,7 @@ methods, all of which have their own (linked) documentation pages. \subsection{Summarize inferences and diagnostics}{\tabular{ll}{ \strong{Method} \tab \strong{Description} \cr \code{\link[=fit-method-summary]{$summary()}} \tab Run \code{\link[posterior:draws_summary]{posterior::summarise_draws()}}. \cr - \code{\link[=fit-method-diagnose_sampler]{$diagnose_sampler()}} \tab Get summaries of sampler diagnostics and warning messages. \cr + \code{\link[=fit-method-diagnostic_summary]{$diagnostic_summary()}} \tab Get summaries of sampler diagnostics and warning messages. \cr \code{\link[=fit-method-cmdstan_summary]{$cmdstan_summary()}} \tab Run and print CmdStan's \code{bin/stansummary}. \cr \code{\link[=fit-method-cmdstan_summary]{$cmdstan_diagnose()}} \tab Run and print CmdStan's \code{bin/diagnose}. \cr \code{\link[=fit-method-loo]{$loo()}} \tab Run \code{\link[loo:loo]{loo::loo.array()}} for approximate LOO-CV \cr diff --git a/man/fit-method-diagnose_sampler.Rd b/man/fit-method-diagnostic_summary.Rd similarity index 82% rename from man/fit-method-diagnose_sampler.Rd rename to man/fit-method-diagnostic_summary.Rd index f8e265d23..dddadbff4 100644 --- a/man/fit-method-diagnose_sampler.Rd +++ b/man/fit-method-diagnostic_summary.Rd @@ -1,11 +1,11 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/fit.R -\name{fit-method-diagnose_sampler} -\alias{fit-method-diagnose_sampler} -\alias{diagnose_sampler} +\name{fit-method-diagnostic_summary} +\alias{fit-method-diagnostic_summary} +\alias{diagnostic_summary} \title{Sampler diagnostic summaries and warnings} \usage{ -diagnose_sampler( +diagnostic_summary( diagnostics = c("divergences", "treedepth", "ebfmi"), quiet = FALSE ) @@ -35,15 +35,15 @@ chain use the \code{\link[=fit-method-sampler_diagnostics]{$sampler_diagnostics( method. Currently parameter-specific diagnostics like R-hat and effective sample -size are not handled by this method. Those diagnostics are provided via the -\code{\link[=fit-method-summary]{$summary()}} method (using +size are \emph{not} handled by this method. Those diagnostics are provided via +the \code{\link[=fit-method-summary]{$summary()}} method (using \code{\link[posterior:draws_summary]{posterior::summarize_draws()}}). } \examples{ \dontrun{ fit <- cmdstanr_example("schools") -fit$diagnose_sampler() -fit$diagnose_sampler(quiet = TRUE) +fit$diagnostic_summary() +fit$diagnostic_summary(quiet = TRUE) } } diff --git a/man/fit-method-sampler_diagnostics.Rd b/man/fit-method-sampler_diagnostics.Rd index 5b48f4449..35692afaf 100644 --- a/man/fit-method-sampler_diagnostics.Rd +++ b/man/fit-method-sampler_diagnostics.Rd @@ -27,7 +27,7 @@ variable). The variables for Stan's default MCMC algorithm are Extract the values of sampler diagnostics for each iteration and chain of MCMC. To instead get summaries of these diagnostics and associated warning messages use the -\code{\link[=fit-method-diagnose_sampler]{$diagnose_sampler()}} method. +\code{\link[=fit-method-diagnostic_summary]{$diagnostic_summary()}} method. } \examples{ \dontrun{ diff --git a/man/model-method-sample.Rd b/man/model-method-sample.Rd index f35a59b71..1fc377387 100644 --- a/man/model-method-sample.Rd +++ b/man/model-method-sample.Rd @@ -255,7 +255,7 @@ and \code{"ebfmi"} (the default is to check all of them). These diagnostics are also available after fitting. The \code{\link[=fit-method-sampler_diagnostics]{$sampler_diagnostics()}} method provides access the diagnostic values for each iteration and the -\code{\link[=fit-method-diagnose_sampler]{$diagnose_sampler()}} method provides +\code{\link[=fit-method-diagnostic_summary]{$diagnostic_summary()}} method provides summaries of the diagnostics and can regenerate the warning messages. Diagnostics like R-hat and effective sample size are not currently diff --git a/man/model-method-sample_mpi.Rd b/man/model-method-sample_mpi.Rd index 45d9c2bfd..95e708818 100644 --- a/man/model-method-sample_mpi.Rd +++ b/man/model-method-sample_mpi.Rd @@ -234,7 +234,7 @@ and \code{"ebfmi"} (the default is to check all of them). These diagnostics are also available after fitting. The \code{\link[=fit-method-sampler_diagnostics]{$sampler_diagnostics()}} method provides access the diagnostic values for each iteration and the -\code{\link[=fit-method-diagnose_sampler]{$diagnose_sampler()}} method provides +\code{\link[=fit-method-diagnostic_summary]{$diagnostic_summary()}} method provides summaries of the diagnostics and can regenerate the warning messages. Diagnostics like R-hat and effective sample size are not currently diff --git a/tests/testthat/test-fit-mcmc.R b/tests/testthat/test-fit-mcmc.R index ea37c7c31..bf2d24bf3 100644 --- a/tests/testthat/test-fit-mcmc.R +++ b/tests/testthat/test-fit-mcmc.R @@ -341,12 +341,12 @@ test_that("draws() errors if invalid format", { ) }) -test_that("diagnose_sampler() works", { +test_that("diagnostic_summary() works", { # will have divergences and treedepth problems fit <- suppressMessages(testing_fit("schools", max_treedepth = 3, seed = 123)) expect_message( - diagnostics <- fit$diagnose_sampler(), + diagnostics <- fit$diagnostic_summary(), "transitions ended with a divergence" ) expect_equal( @@ -354,7 +354,7 @@ test_that("diagnose_sampler() works", { suppressMessages(check_divergences(fit$sampler_diagnostics())) ) expect_message( - diagnostics <- fit$diagnose_sampler(), + diagnostics <- fit$diagnostic_summary(), "transitions hit the maximum treedepth limit of 3" ) expect_equal( @@ -369,7 +369,7 @@ test_that("diagnose_sampler() works", { # ebfmi not defined if iter < 3 fit <- suppressWarnings(suppressMessages(testing_fit("schools", iter_sampling = 2))) expect_warning( - diagnostics <- fit$diagnose_sampler(), + diagnostics <- fit$diagnostic_summary(), "E-BFMI not computed" ) expect_equal(diagnostics$ebfmi, NA) diff --git a/vignettes/cmdstanr.Rmd b/vignettes/cmdstanr.Rmd index 4bf8b3222..60d0867aa 100644 --- a/vignettes/cmdstanr.Rmd +++ b/vignettes/cmdstanr.Rmd @@ -266,10 +266,10 @@ str(fit$sampler_diagnostics(format = "df")) #### Sampler diagnostic warnings and summaries -The `$diagnose_sampler()` method will display any sampler diagnostic warnings and return a summary of diagnostics for each chain. +The `$diagnostic_summary()` method will display any sampler diagnostic warnings and return a summary of diagnostics for each chain. -```{r diagnose_sampler} -fit$diagnose_sampler() +```{r diagnostic_summary} +fit$diagnostic_summary() ``` We see the number of divergences for each of the four chains, the number @@ -283,10 +283,10 @@ divergences. ```{r fit-with-warnings, results='hold'} fit_with_warning <- cmdstanr_example("schools") ``` -After fitting there is a warning about divergences. We can also regenerate this warning message later using `fit$diagnose_sampler()`. +After fitting there is a warning about divergences. We can also regenerate this warning message later using `fit$diagnostic_summary()`. -```{r diagnose_sampler-with-warnings} -diagnostics <- fit_with_warning$diagnose_sampler() +```{r diagnostic_summary-with-warnings} +diagnostics <- fit_with_warning$diagnostic_summary() print(diagnostics) # number of divergences reported in warning is the sum of the per chain values From cdf12ed7f1d87b7efdb4b38afb57cd72c2fc1e46 Mon Sep 17 00:00:00 2001 From: jgabry Date: Tue, 9 Nov 2021 10:03:58 -0700 Subject: [PATCH 38/44] cleanup --- R/fit.R | 28 +++++++++++++--------------- R/run.R | 2 +- R/utils.R | 7 +++++-- tests/testthat/test-fit-mcmc.R | 3 +++ tests/testthat/test-model-sample.R | 1 + 5 files changed, 23 insertions(+), 18 deletions(-) diff --git a/R/fit.R b/R/fit.R index d21e8fa31..694f04d98 100644 --- a/R/fit.R +++ b/R/fit.R @@ -857,20 +857,15 @@ CmdStanMCMC <- R6::R6Class( warning("No chains finished successfully. Unable to retrieve the fit.", call. = FALSE) } else { - # throw diagnostic warnings if user asked for them and if not fixed_param - if (!is.null(self$runset$args$method_args$diagnostics)) { + if (runset$args$method_args$fixed_param) { + private$read_csv_(variables = "", sampler_diagnostics = "") + } else { diagnostics <- self$runset$args$method_args$diagnostics - fixed_param <- runset$args$method_args$fixed_param - if (!fixed_param) { - # convert user friendly names to actual diagnostic names (e.g. divergences --> divergent__) - diagnostics_to_read <- convert_hmc_diagnostic_names(diagnostics) - } else { - diagnostics_to_read <- "" - } - private$read_csv_(variables = "", sampler_diagnostics = diagnostics_to_read) - if (!fixed_param) { - invisible(self$diagnostic_summary(diagnostics = diagnostics, quiet = FALSE)) - } + private$read_csv_( + variables = "", + sampler_diagnostics = convert_hmc_diagnostic_names(diagnostics) + ) + invisible(self$diagnostic_summary(diagnostics, quiet = FALSE)) } } }, @@ -1153,13 +1148,16 @@ CmdStanMCMC$set("public", name = "sampler_diagnostics", value = sampler_diagnost #' } #' diagnostic_summary <- function(diagnostics = c("divergences", "treedepth", "ebfmi"), quiet = FALSE) { + out <- list() + if (is.null(diagnostics) || identical(diagnostics, "")) { + return(out) + } diagnostics <- match.arg( diagnostics, - choices = available_hmc_diagnostics(), + choices = c(available_hmc_diagnostics()), several.ok = TRUE ) post_warmup_sampler_diagnostics <- self$sampler_diagnostics(inc_warmup = FALSE) - out <- list() if ("divergences" %in% diagnostics) { if (quiet) { divergences <- suppressMessages(check_divergences(post_warmup_sampler_diagnostics)) diff --git a/R/run.R b/R/run.R index 361f55037..e0fd3855b 100644 --- a/R/run.R +++ b/R/run.R @@ -919,7 +919,7 @@ CmdStanMCMCProcs <- R6::R6Class( "seconds.\n") cat("Total execution time:", format(round(self$total_time(), 1), nsmall = 1), - "seconds.\n") + "seconds.\n\n") } else if (num_failed == num_chains) { warning("All chains finished unexpectedly! Use the $output(chain_id) method for more information.\n", call. = FALSE) warning("Use read_cmdstan_csv() to read the results of the failed chains.", diff --git a/R/utils.R b/R/utils.R index b41681623..004bb581c 100644 --- a/R/utils.R +++ b/R/utils.R @@ -248,7 +248,7 @@ check_divergences <- function(post_warmup_sampler_diagnostics) { if (!is.na(num_divergences) && num_divergences > 0) { percentage_divergences <- 100 * num_divergences / num_draws message( - "\nWarning: ", num_divergences, " of ", num_draws, + "Warning: ", num_divergences, " of ", num_draws, " (", (format(round(percentage_divergences, 0), nsmall = 1)), "%)", " transitions ended with a divergence.\n", "This may indicate insufficient exploration of the posterior distribution.\n", @@ -272,7 +272,7 @@ check_max_treedepth <- function(post_warmup_sampler_diagnostics, metadata) { if (!is.na(num_max_treedepths) && num_max_treedepths > 0) { percentage_max_treedepths <- 100 * num_max_treedepths / num_draws message( - num_max_treedepths, " of ", num_draws, " (", (format(round(percentage_max_treedepths, 0), nsmall = 1)), "%)", + "Warning: ", num_max_treedepths, " of ", num_draws, " (", (format(round(percentage_max_treedepths, 0), nsmall = 1)), "%)", " transitions hit the maximum treedepth limit of ", metadata$max_treedepth, " or 2^", metadata$max_treedepth, "-1 leapfrog steps.\n", "Trajectories that are prematurely terminated due to this limit will result in slow exploration.\n", @@ -340,6 +340,9 @@ convert_hmc_diagnostic_names <- function(diagnostics) { if ("ebfmi" %in% diagnostics) { diagnostic_names <- c(diagnostic_names, "energy__") } + if (length(diagnostic_names) == 0) { + diagnostic_names <- "" + } diagnostic_names } diff --git a/tests/testthat/test-fit-mcmc.R b/tests/testthat/test-fit-mcmc.R index bf2d24bf3..aae8386f3 100644 --- a/tests/testthat/test-fit-mcmc.R +++ b/tests/testthat/test-fit-mcmc.R @@ -373,4 +373,7 @@ test_that("diagnostic_summary() works", { "E-BFMI not computed" ) expect_equal(diagnostics$ebfmi, NA) + + expect_equal(fit$diagnostic_summary(""), list()) + expect_equal(fit$diagnostic_summary(NULL), list()) }) diff --git a/tests/testthat/test-model-sample.R b/tests/testthat/test-model-sample.R index e817a30fb..e53d45e7a 100644 --- a/tests/testthat/test-model-sample.R +++ b/tests/testthat/test-model-sample.R @@ -333,3 +333,4 @@ generated quantities { expect_null(fit$sampler_diagnostics()) expect_equal(posterior::variables(fit$draws()), "y") }) + From ac673dcfb8bdc91625e5da266569e748b8eebe4a Mon Sep 17 00:00:00 2001 From: jgabry Date: Tue, 9 Nov 2021 10:12:02 -0700 Subject: [PATCH 39/44] minor doc edit --- R/fit.R | 2 +- man/fit-method-diagnostic_summary.Rd | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/fit.R b/R/fit.R index 694f04d98..d06fbcc6f 100644 --- a/R/fit.R +++ b/R/fit.R @@ -1126,7 +1126,7 @@ CmdStanMCMC$set("public", name = "sampler_diagnostics", value = sampler_diagnost #' #' @param diagnostics (character vector) One or more diagnostics to check. The #' currently supported diagnostics are `"divergences`, `"treedepth"`, and -#' `"ebfmi`. +#' `"ebfmi`. The default is to check all of them. #' @param quiet (logical) Should warning messages about the diagnostics be #' suppressed? The default is `FALSE`, in which case warning messages are #' printed in addition to returning the values of the diagnostics. diff --git a/man/fit-method-diagnostic_summary.Rd b/man/fit-method-diagnostic_summary.Rd index dddadbff4..e6d1909a3 100644 --- a/man/fit-method-diagnostic_summary.Rd +++ b/man/fit-method-diagnostic_summary.Rd @@ -13,7 +13,7 @@ diagnostic_summary( \arguments{ \item{diagnostics}{(character vector) One or more diagnostics to check. The currently supported diagnostics are \verb{"divergences}, \code{"treedepth"}, and -\verb{"ebfmi}.} +\verb{"ebfmi}. The default is to check all of them.} \item{quiet}{(logical) Should warning messages about the diagnostics be suppressed? The default is \code{FALSE}, in which case warning messages are From 871a222133068ca47d3bc6aa2ba98c064ee041ae Mon Sep 17 00:00:00 2001 From: jgabry Date: Wed, 10 Nov 2021 11:24:33 -0700 Subject: [PATCH 40/44] Update model-method-sample.Rd --- man/model-method-sample.Rd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/man/model-method-sample.Rd b/man/model-method-sample.Rd index 6ad68731f..6996c0f51 100644 --- a/man/model-method-sample.Rd +++ b/man/model-method-sample.Rd @@ -241,7 +241,7 @@ Disable if you wish to silence these messages, but this is not usually recommended unless you are very confident that the model is correct up to numerical error. If the messages are silenced then the \code{\link[=fit-method-output]{$output()}} method of the resulting fit object can be -used to display all the silenced messages.} +used to display the silenced messages.} \item{diagnostics}{(character vector) The diagnostics to automatically check and warn about after sampling. Setting this to an empty string \code{""} or From 84f60bf3b95be8868dc7e852afec01a8509dd14a Mon Sep 17 00:00:00 2001 From: jgabry Date: Wed, 10 Nov 2021 11:42:33 -0700 Subject: [PATCH 41/44] minor doc edit --- R/fit.R | 2 +- man-roxygen/model-sample-args.R | 2 +- man/model-method-sample.Rd | 2 +- man/model-method-sample_mpi.Rd | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/R/fit.R b/R/fit.R index d06fbcc6f..af7d5664c 100644 --- a/R/fit.R +++ b/R/fit.R @@ -1154,7 +1154,7 @@ diagnostic_summary <- function(diagnostics = c("divergences", "treedepth", "ebfm } diagnostics <- match.arg( diagnostics, - choices = c(available_hmc_diagnostics()), + choices = available_hmc_diagnostics(), several.ok = TRUE ) post_warmup_sampler_diagnostics <- self$sampler_diagnostics(inc_warmup = FALSE) diff --git a/man-roxygen/model-sample-args.R b/man-roxygen/model-sample-args.R index eadce3636..bbaa0ef3d 100644 --- a/man-roxygen/model-sample-args.R +++ b/man-roxygen/model-sample-args.R @@ -95,7 +95,7 @@ #' [`$diagnostic_summary()`][fit-method-diagnostic_summary] method provides #' summaries of the diagnostics and can regenerate the warning messages. #' -#' Diagnostics like R-hat and effective sample size are not currently +#' Diagnostics like R-hat and effective sample size are _not_ currently #' available via the `diagnostics` argument but can be checked after fitting #' using the [`$summary()`][fit-method-summary] method. #' diff --git a/man/model-method-sample.Rd b/man/model-method-sample.Rd index 6996c0f51..42e4cedb0 100644 --- a/man/model-method-sample.Rd +++ b/man/model-method-sample.Rd @@ -257,7 +257,7 @@ access the diagnostic values for each iteration and the \code{\link[=fit-method-diagnostic_summary]{$diagnostic_summary()}} method provides summaries of the diagnostics and can regenerate the warning messages. -Diagnostics like R-hat and effective sample size are not currently +Diagnostics like R-hat and effective sample size are \emph{not} currently available via the \code{diagnostics} argument but can be checked after fitting using the \code{\link[=fit-method-summary]{$summary()}} method.} diff --git a/man/model-method-sample_mpi.Rd b/man/model-method-sample_mpi.Rd index e0ab2cd5e..c8787e3cf 100644 --- a/man/model-method-sample_mpi.Rd +++ b/man/model-method-sample_mpi.Rd @@ -235,7 +235,7 @@ access the diagnostic values for each iteration and the \code{\link[=fit-method-diagnostic_summary]{$diagnostic_summary()}} method provides summaries of the diagnostics and can regenerate the warning messages. -Diagnostics like R-hat and effective sample size are not currently +Diagnostics like R-hat and effective sample size are \emph{not} currently available via the \code{diagnostics} argument but can be checked after fitting using the \code{\link[=fit-method-summary]{$summary()}} method.} From 0763bfef53492afc1dd179b7bb352573ec476543 Mon Sep 17 00:00:00 2001 From: jgabry Date: Thu, 11 Nov 2021 17:53:50 -0700 Subject: [PATCH 42/44] use diagnostic_summary for as_cmdstan_fit too --- R/csv.R | 11 ++++++----- tests/testthat/test-csv.R | 17 +++++++++++++++++ 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/R/csv.R b/R/csv.R index cc0611eae..9be121992 100644 --- a/R/csv.R +++ b/R/csv.R @@ -444,11 +444,6 @@ CmdStanMCMC_CSV <- R6::R6Class( inherit = CmdStanMCMC, public = list( initialize = function(csv_contents, files, check_diagnostics = TRUE) { - if (check_diagnostics) { - check_divergences(csv_contents$post_warmup_sampler_diagnostics) - check_max_treedepth(csv_contents$post_warmup_sampler_diagnostics, csv_contents$metadata) - check_ebfmi(csv_contents$post_warmup_sampler_diagnostics) - } private$output_files_ <- files private$metadata_ <- csv_contents$metadata private$time_ <- csv_contents$time @@ -457,6 +452,10 @@ CmdStanMCMC_CSV <- R6::R6Class( private$warmup_sampler_diagnostics_ <- csv_contents$warmup_sampler_diagnostics private$warmup_draws_ <- csv_contents$warmup_draws private$draws_ <- csv_contents$post_warmup_draws + if (check_diagnostics) { + invisible(self$diagnostic_summary()) + } + invisible(self) }, # override some methods so they work without a CmdStanRun object output_files = function(...) { @@ -482,6 +481,7 @@ CmdStanMLE_CSV <- R6::R6Class( private$output_files_ <- files private$draws_ <- csv_contents$point_estimates private$metadata_ <- csv_contents$metadata + invisible(self) }, output_files = function(...) { private$output_files_ @@ -497,6 +497,7 @@ CmdStanVB_CSV <- R6::R6Class( private$output_files_ <- files private$draws_ <- csv_contents$draws private$metadata_ <- csv_contents$metadata + invisible(self) }, output_files = function(...) { private$output_files_ diff --git a/tests/testthat/test-csv.R b/tests/testthat/test-csv.R index df74d73cc..d22218242 100644 --- a/tests/testthat/test-csv.R +++ b/tests/testthat/test-csv.R @@ -528,6 +528,23 @@ test_that("as_cmdstan_fit creates fitted model objects from csv", { } }) +test_that("as_cmdstan_fit can check MCMC diagnostics", { + fit_schools <- suppressMessages( + testing_fit("schools", chains = 2, + adapt_delta = 0.5, max_treedepth = 4, + show_messages = FALSE) + ) + expect_message( + as_cmdstan_fit(fit$output_files()), + "transitions ended with a divergence" + ) + expect_message( + as_cmdstan_fit(fit$output_files()), + "transitions hit the maximum treedepth" + ) + expect_silent(as_cmdstan_fit(fit$output_files(), check_diagnostics = FALSE)) +}) + test_that("read_cmdstan_csv reads seed correctly", { opt <- read_cmdstan_csv(fit_bernoulli_optimize$output_files()) vi <- read_cmdstan_csv(fit_bernoulli_variational$output_files()) From 33fc2e6826ec57d231effdbca6b35e33a7903994 Mon Sep 17 00:00:00 2001 From: jgabry Date: Thu, 11 Nov 2021 18:43:36 -0700 Subject: [PATCH 43/44] Update test-csv.R --- tests/testthat/test-csv.R | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/testthat/test-csv.R b/tests/testthat/test-csv.R index d22218242..4c1215451 100644 --- a/tests/testthat/test-csv.R +++ b/tests/testthat/test-csv.R @@ -535,14 +535,14 @@ test_that("as_cmdstan_fit can check MCMC diagnostics", { show_messages = FALSE) ) expect_message( - as_cmdstan_fit(fit$output_files()), + as_cmdstan_fit(fit_schools$output_files()), "transitions ended with a divergence" ) expect_message( - as_cmdstan_fit(fit$output_files()), + as_cmdstan_fit(fit_schools$output_files()), "transitions hit the maximum treedepth" ) - expect_silent(as_cmdstan_fit(fit$output_files(), check_diagnostics = FALSE)) + expect_silent(as_cmdstan_fit(fit_schools$output_files(), check_diagnostics = FALSE)) }) test_that("read_cmdstan_csv reads seed correctly", { From 1c458abe8e98c170a7c3450298272d66e5abaf6b Mon Sep 17 00:00:00 2001 From: jgabry Date: Thu, 10 Mar 2022 16:53:27 -0700 Subject: [PATCH 44/44] point to new warnings webpage closes #505 --- R/utils.R | 18 +++++------------- tests/testthat/test-utils.R | 10 +++++----- 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/R/utils.R b/R/utils.R index f5e55818f..ee9216675 100644 --- a/R/utils.R +++ b/R/utils.R @@ -251,11 +251,7 @@ check_divergences <- function(post_warmup_sampler_diagnostics) { "Warning: ", num_divergences, " of ", num_draws, " (", (base::format(round(percentage_divergences, 0), nsmall = 1)), "%)", " transitions ended with a divergence.\n", - "This may indicate insufficient exploration of the posterior distribution.\n", - "Possible remedies include: \n", - " * Increasing adapt_delta closer to 1 (default is 0.8) \n", - " * Reparameterizing the model (e.g. using a non-centered parameterization)\n", - " * Using informative or weakly informative prior distributions \n" + "See https://mc-stan.org/misc/warnings for details.\n" ) } } @@ -273,11 +269,8 @@ check_max_treedepth <- function(post_warmup_sampler_diagnostics, metadata) { percentage_max_treedepths <- 100 * num_max_treedepths / num_draws message( "Warning: ", num_max_treedepths, " of ", num_draws, " (", (base::format(round(percentage_max_treedepths, 0), nsmall = 1)), "%)", - " transitions hit the maximum treedepth limit of ", metadata$max_treedepth, - " or 2^", metadata$max_treedepth, "-1 leapfrog steps.\n", - "Trajectories that are prematurely terminated due to this limit will result in slow exploration.\n", - "Increasing the max_treedepth limit can avoid this at the expense of more computation.\n", - "If increasing max_treedepth does not remove warnings, try to reparameterize the model.\n" + " transitions hit the maximum treedepth limit of ", metadata$max_treedepth,".\n", + "See https://mc-stan.org/misc/warnings for details.\n" ) } } @@ -310,9 +303,8 @@ check_ebfmi <- function(post_warmup_sampler_diagnostics, threshold = 0.2) { if (any(efbmi_per_chain < threshold)) { message( "Warning: ", sum(efbmi_per_chain < threshold), " of ", length(efbmi_per_chain), - " chains had energy-based Bayesian fraction of missing information (E-BFMI)", - " less than ", threshold, ".", - "\nThis may indicate poor exploration of the posterior.\n" + " chains had an E-BFMI less than ", threshold, ".\n", + "See https://mc-stan.org/misc/warnings for details.\n" ) } invisible(unname(efbmi_per_chain)) diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R index da7065b28..147b1e143 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-utils.R @@ -49,7 +49,7 @@ test_that("check_divergences() works", { test_that("check_max_treedepth() works", { csv_files <- c(test_path("resources", "csv", "model1-2-no-warmup.csv")) csv_output <- read_cmdstan_csv(csv_files) - output <- "16 of 100 \\(16.0%\\) transitions hit the maximum treedepth limit of 5 or 2\\^5-1 leapfrog steps." + output <- "16 of 100 \\(16.0%\\) transitions hit the maximum treedepth limit of 5." expect_message( max_tds <- check_max_treedepth( csv_output$post_warmup_sampler_diagnostics, @@ -61,7 +61,7 @@ test_that("check_max_treedepth() works", { csv_files <- c(test_path("resources", "csv", "model1-2-no-warmup.csv"), test_path("resources", "csv", "model1-2-no-warmup.csv")) csv_output <- read_cmdstan_csv(csv_files) - output <- "32 of 200 \\(16.0%\\) transitions hit the maximum treedepth limit of 5 or 2\\^5-1 leapfrog steps." + output <- "32 of 200 \\(16.0%\\) transitions hit the maximum treedepth limit of 5." expect_message( max_tds <- check_max_treedepth( csv_output$post_warmup_sampler_diagnostics, @@ -72,7 +72,7 @@ test_that("check_max_treedepth() works", { # force different number of max treedepths per chain just to test csv_output$post_warmup_sampler_diagnostics[1, 1:2, "treedepth__"] <- c(1, 15) - output <- "31 of 200 \\(16.0%\\) transitions hit the maximum treedepth limit of 5 or 2\\^5-1 leapfrog steps." + output <- "31 of 200 \\(16.0%\\) transitions hit the maximum treedepth limit of 5." expect_message( max_tds <- check_max_treedepth( csv_output$post_warmup_sampler_diagnostics, @@ -83,7 +83,7 @@ test_that("check_max_treedepth() works", { csv_files <- c(test_path("resources", "csv", "model1-2-warmup.csv")) csv_output <- read_cmdstan_csv(csv_files) - output <- "1 of 100 \\(1.0%\\) transitions hit the maximum treedepth limit of 5 or 2\\^5-1 leapfrog steps." + output <- "1 of 100 \\(1.0%\\) transitions hit the maximum treedepth limit of 5." expect_message( check_max_treedepth( csv_output$post_warmup_sampler_diagnostics, @@ -102,7 +102,7 @@ test_that("check_ebfmi and computing ebfmi works", { energy_df$energy__[i+1] <- energy_df$energy__[i] + rnorm(1, 0, 0.01) } energy_df <- posterior::as_draws(energy_df) - expect_message(check_ebfmi(energy_df), "fraction of missing information \\(E-BFMI\\) less than") + expect_message(check_ebfmi(energy_df), "had an E-BFMI less than") energy_vec <- energy_df$energy__ check_val <- (sum(diff(energy_vec)^2) / length(energy_vec)) / stats::var(energy_vec) expect_equal(as.numeric(ebfmi(energy_df)), check_val)