Skip to content

Commit

Permalink
Make exported RNG functions respect changes to R's seed (#973)
Browse files Browse the repository at this point in the history
* Make exported RNG functions respect changes to R's seed

* Simpler seed setting

* Seed set location

* Fix test
  • Loading branch information
andrjohns authored May 18, 2024
1 parent c218b0d commit b271b7b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 16 deletions.
17 changes: 12 additions & 5 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -887,12 +887,16 @@ prep_fun_cpp <- function(fun_start, fun_end, model_lines) {
}
fun_body <- gsub("// [[stan::function]]", "// [[Rcpp::export]]\n", fun_body, fixed = TRUE)
fun_body <- gsub("std::ostream\\*\\s*pstream__\\s*=\\s*nullptr", "", fun_body)
if (cmdstan_version() < "2.35.0") {
fun_body <- gsub("boost::ecuyer1988&\\s*base_rng__", "SEXP base_rng_ptr", fun_body)
} else {
fun_body <- gsub("stan::rng_t&\\s*base_rng__", "SEXP base_rng_ptr", fun_body)
if (grepl("(stan::rng_t|boost::ecuyer1988)", fun_body)) {
if (cmdstan_version() < "2.35.0") {
fun_body <- gsub("boost::ecuyer1988&\\s*base_rng__", "SEXP base_rng_ptr, SEXP seed", fun_body)
} else {
fun_body <- gsub("stan::rng_t&\\s*base_rng__", "SEXP base_rng_ptr, SEXP seed", fun_body)
}
rng_seed <- "Rcpp::XPtr<stan::rng_t> base_rng(base_rng_ptr);base_rng->seed(Rcpp::as<int>(seed));"
fun_body <- gsub("return", paste(rng_seed, "return"), fun_body)
fun_body <- gsub("base_rng__,", "*(base_rng.get()),", fun_body, fixed = TRUE)
}
fun_body <- gsub("base_rng__,", "*(Rcpp::XPtr<stan::rng_t>(base_rng_ptr).get()),", fun_body, fixed = TRUE)
fun_body <- gsub("pstream__", "&Rcpp::Rcout", fun_body, fixed = TRUE)
fun_body <- paste(fun_body, collapse = "\n")
gsub(pattern = ",\\s*)", replacement = ")", fun_body)
Expand Down Expand Up @@ -953,6 +957,9 @@ compile_functions <- function(env, verbose = FALSE, global = FALSE) {
fundef <- get(fun, envir = fun_env)
funargs <- formals(fundef)
funargs$base_rng_ptr <- env$rng_ptr
# To allow for exported RNG functions to respect the R 'set.seed()' call,
# we need to derive a seed deterministically from the current RNG state
funargs$seed <- quote(sample.int(.Machine$integer.max, 1))
formals(fundef) <- funargs
assign(fun, fundef, envir = fun_env)
}
Expand Down
19 changes: 8 additions & 11 deletions tests/testthat/test-model-expose-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -346,19 +346,16 @@ test_that("rng functions can be exposed", {
mod <- cmdstan_model(model, force_recompile = TRUE)
fit <- mod$sample(data = data_list)

set.seed(10)
fit$expose_functions(verbose = TRUE)
set.seed(10)
res1_1 <- fit$functions$wrap_normal_rng(5,10)
res2_1 <- fit$functions$wrap_normal_rng(5,10)
set.seed(10)
res1_2 <- fit$functions$wrap_normal_rng(5,10)
res2_2 <- fit$functions$wrap_normal_rng(5,10)

expect_equal(
fit$functions$wrap_normal_rng(5,10),
# Stan RNG changed in 2.35
ifelse(cmdstan_version() < "2.35.0",-4.529876423, 0.02974925)
)

expect_equal(
fit$functions$wrap_normal_rng(5,10),
ifelse(cmdstan_version() < "2.35.0", 8.12959026, 10.3881349)
)
expect_equal(res1_1, res1_2)
expect_equal(res2_1, res2_2)
})

test_that("Overloaded functions give meaningful errors", {
Expand Down

0 comments on commit b271b7b

Please sign in to comment.