Skip to content

Commit

Permalink
I/O for RNG state (#78)
Browse files Browse the repository at this point in the history
* Add new functions dqrng_get/set_state
* Reduce false positive misses in test coverage
  • Loading branch information
rstub authored Apr 7, 2024
1 parent dcaaa0e commit 60e0976
Show file tree
Hide file tree
Showing 14 changed files with 281 additions and 8 deletions.
2 changes: 1 addition & 1 deletion LICENSE.note
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ License: AGPL-3

Files: *
Copyright: 2018-2019 Ralf Stubner (daqana GmbH)
2022-2023 Ralf Stubner
2022-2024 Ralf Stubner
License: AGPL-3
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
export(dqRNGkind)
export(dqrexp)
export(dqrmvnorm)
export(dqrng_get_state)
export(dqrng_set_state)
export(dqrnorm)
export(dqrrademacher)
export(dqrunif)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
* Add Xoroshiro128\*\*/++ and Xoshiro256\*\*/++ to `xoshiro.h`
* Allow uniform and normal distributions to be registered as user-supplied RNG within R. This happens automatically if the option `dqrng.register_methods` is set to `TRUE`.
* Add missing inline attributes and limit the included Rcpp headers in `dqrng_types.h` ([#75](https://github.com/daqana/dqrng/pull/75) together with Paul Liétar)
* Add I/O methods for the RNG's internal state (fixing [#66](https://github.com/daqana/dqrng/issues/66) in [#78](https://github.com/daqana/dqrng/pull/78))


# dgrng 0.3.2

Expand Down
12 changes: 12 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,18 @@ dqRNGkind <- function(kind, normal_kind = "ignored") {
invisible(.Call(`_dqrng_dqRNGkind`, kind, normal_kind))
}

#' @rdname dqrng-functions
#' @export
dqrng_get_state <- function() {
.Call(`_dqrng_dqrng_get_state`)
}

#' @rdname dqrng-functions
#' @export
dqrng_set_state <- function(state) {
invisible(.Call(`_dqrng_dqrng_set_state`, state))
}

#' @rdname dqrng-functions
#' @export
dqrunif <- function(n, min = 0.0, max = 1.0) {
Expand Down
5 changes: 1 addition & 4 deletions R/dqrmv.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@
#' plot(x)
dqrmvnorm <- function(n, ...) {
if (!requireNamespace("mvtnorm", quietly = TRUE)) {
stop(
"Package \"mvtnorm\" must be installed to use this function.",
call. = FALSE
)
stop("Package \"mvtnorm\" must be installed to use this function.", call. = FALSE)
}
mvtnorm::rmvnorm(n, ..., rnorm = dqrnorm)
}
15 changes: 14 additions & 1 deletion R/dqset.seed.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
#' generates a random 64 bit integer and then uses each bit to generate
#' a 0/1 variable. This generates 64 integers per random number generation.
#'
#' \code{dqrng_get_state} and \code{dqrng_set_state} can be used to get and set
#' the RNG's internal state. The character vector should not be manipulated directly.
#'
#' @param seed integer scalar to seed the random number generator, or an integer vector of length 2 representing a 64-bit seed. Maybe \code{NULL}, see details.
#' @param stream integer used for selecting the RNG stream; either a scalar or a vector of length 2
#' @param kind string specifying the RNG (see details)
Expand All @@ -22,8 +25,11 @@
#' @param mean mean value of the normal distribution
#' @param sd standard deviation of the normal distribution
#' @param rate rate of the exponential distribution
#' @param state character vector representation of the RNG's internal state
#'
#' @return \code{dqrunif}, \code{dqrnorm}, and \code{dqrexp} return a numeric vector of length \code{n}. \code{dqrrademacher} returns an integer vector of length \code{n}.
#' @return \code{dqrunif}, \code{dqrnorm}, and \code{dqrexp} return a numeric vector
#' of length \code{n}. \code{dqrrademacher} returns an integer vector of length \code{n}.
#' \code{dqrng_get_state} returns a character vector representation of the RNG's internal state.
#'
#' @details Supported RNG kinds:
#' \describe{
Expand Down Expand Up @@ -75,6 +81,13 @@
#' dqrunif(5, min = 2, max = 10)
#' dqrexp(5, rate = 4)
#' dqrnorm(5, mean = 5, sd = 3)
#'
#' # get and restore the state
#' (state <- dqrng_get_state())
#' dqrunif(5)
#' dqrng_set_state(state)
#' dqrunif(5)
#'
#' @rdname dqrng-functions
#' @export
dqset.seed <- function(seed, stream = NULL) {
Expand Down
3 changes: 3 additions & 0 deletions codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ coverage:
default:
target: auto
threshold: 1%
ignore:
- "R/zzz.R"
- "inst/include/pcg_*"
39 changes: 39 additions & 0 deletions inst/include/dqrng_RcppExports.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,45 @@ namespace dqrng {
throw Rcpp::exception(Rcpp::as<std::string>(rcpp_result_gen).c_str());
}

inline std::vector<std::string> dqrng_get_state() {
typedef SEXP(*Ptr_dqrng_get_state)();
static Ptr_dqrng_get_state p_dqrng_get_state = NULL;
if (p_dqrng_get_state == NULL) {
validateSignature("std::vector<std::string>(*dqrng_get_state)()");
p_dqrng_get_state = (Ptr_dqrng_get_state)R_GetCCallable("dqrng", "_dqrng_dqrng_get_state");
}
RObject rcpp_result_gen;
{
rcpp_result_gen = p_dqrng_get_state();
}
if (rcpp_result_gen.inherits("interrupted-error"))
throw Rcpp::internal::InterruptedException();
if (Rcpp::internal::isLongjumpSentinel(rcpp_result_gen))
throw Rcpp::LongjumpException(rcpp_result_gen);
if (rcpp_result_gen.inherits("try-error"))
throw Rcpp::exception(Rcpp::as<std::string>(rcpp_result_gen).c_str());
return Rcpp::as<std::vector<std::string> >(rcpp_result_gen);
}

inline void dqrng_set_state(std::vector<std::string> state) {
typedef SEXP(*Ptr_dqrng_set_state)(SEXP);
static Ptr_dqrng_set_state p_dqrng_set_state = NULL;
if (p_dqrng_set_state == NULL) {
validateSignature("void(*dqrng_set_state)(std::vector<std::string>)");
p_dqrng_set_state = (Ptr_dqrng_set_state)R_GetCCallable("dqrng", "_dqrng_dqrng_set_state");
}
RObject rcpp_result_gen;
{
rcpp_result_gen = p_dqrng_set_state(Shield<SEXP>(Rcpp::wrap(state)));
}
if (rcpp_result_gen.inherits("interrupted-error"))
throw Rcpp::internal::InterruptedException();
if (Rcpp::internal::isLongjumpSentinel(rcpp_result_gen))
throw Rcpp::LongjumpException(rcpp_result_gen);
if (rcpp_result_gen.inherits("try-error"))
throw Rcpp::exception(Rcpp::as<std::string>(rcpp_result_gen).c_str());
}

inline Rcpp::NumericVector dqrunif(size_t n, double min = 0.0, double max = 1.0) {
typedef SEXP(*Ptr_dqrunif)(SEXP,SEXP,SEXP);
static Ptr_dqrunif p_dqrunif = NULL;
Expand Down
22 changes: 21 additions & 1 deletion man/dqrng-functions.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

70 changes: 70 additions & 0 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,70 @@ RcppExport SEXP _dqrng_dqRNGkind(SEXP kindSEXP, SEXP normal_kindSEXP) {
UNPROTECT(1);
return rcpp_result_gen;
}
// dqrng_get_state
std::vector<std::string> dqrng_get_state();
static SEXP _dqrng_dqrng_get_state_try() {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
rcpp_result_gen = Rcpp::wrap(dqrng_get_state());
return rcpp_result_gen;
END_RCPP_RETURN_ERROR
}
RcppExport SEXP _dqrng_dqrng_get_state() {
SEXP rcpp_result_gen;
{
rcpp_result_gen = PROTECT(_dqrng_dqrng_get_state_try());
}
Rboolean rcpp_isInterrupt_gen = Rf_inherits(rcpp_result_gen, "interrupted-error");
if (rcpp_isInterrupt_gen) {
UNPROTECT(1);
Rf_onintr();
}
bool rcpp_isLongjump_gen = Rcpp::internal::isLongjumpSentinel(rcpp_result_gen);
if (rcpp_isLongjump_gen) {
Rcpp::internal::resumeJump(rcpp_result_gen);
}
Rboolean rcpp_isError_gen = Rf_inherits(rcpp_result_gen, "try-error");
if (rcpp_isError_gen) {
SEXP rcpp_msgSEXP_gen = Rf_asChar(rcpp_result_gen);
UNPROTECT(1);
Rf_error("%s", CHAR(rcpp_msgSEXP_gen));
}
UNPROTECT(1);
return rcpp_result_gen;
}
// dqrng_set_state
void dqrng_set_state(std::vector<std::string> state);
static SEXP _dqrng_dqrng_set_state_try(SEXP stateSEXP) {
BEGIN_RCPP
Rcpp::traits::input_parameter< std::vector<std::string> >::type state(stateSEXP);
dqrng_set_state(state);
return R_NilValue;
END_RCPP_RETURN_ERROR
}
RcppExport SEXP _dqrng_dqrng_set_state(SEXP stateSEXP) {
SEXP rcpp_result_gen;
{
rcpp_result_gen = PROTECT(_dqrng_dqrng_set_state_try(stateSEXP));
}
Rboolean rcpp_isInterrupt_gen = Rf_inherits(rcpp_result_gen, "interrupted-error");
if (rcpp_isInterrupt_gen) {
UNPROTECT(1);
Rf_onintr();
}
bool rcpp_isLongjump_gen = Rcpp::internal::isLongjumpSentinel(rcpp_result_gen);
if (rcpp_isLongjump_gen) {
Rcpp::internal::resumeJump(rcpp_result_gen);
}
Rboolean rcpp_isError_gen = Rf_inherits(rcpp_result_gen, "try-error");
if (rcpp_isError_gen) {
SEXP rcpp_msgSEXP_gen = Rf_asChar(rcpp_result_gen);
UNPROTECT(1);
Rf_error("%s", CHAR(rcpp_msgSEXP_gen));
}
UNPROTECT(1);
return rcpp_result_gen;
}
// dqrunif
Rcpp::NumericVector dqrunif(size_t n, double min, double max);
static SEXP _dqrng_dqrunif_try(SEXP nSEXP, SEXP minSEXP, SEXP maxSEXP) {
Expand Down Expand Up @@ -443,6 +507,8 @@ static int _dqrng_RcppExport_validate(const char* sig) {
if (signatures.empty()) {
signatures.insert("void(*dqset_seed)(Rcpp::Nullable<Rcpp::IntegerVector>,Rcpp::Nullable<Rcpp::IntegerVector>)");
signatures.insert("void(*dqRNGkind)(std::string,const std::string&)");
signatures.insert("std::vector<std::string>(*dqrng_get_state)()");
signatures.insert("void(*dqrng_set_state)(std::vector<std::string>)");
signatures.insert("Rcpp::NumericVector(*dqrunif)(size_t,double,double)");
signatures.insert("double(*runif)(double,double)");
signatures.insert("Rcpp::NumericVector(*dqrnorm)(size_t,double,double)");
Expand All @@ -461,6 +527,8 @@ static int _dqrng_RcppExport_validate(const char* sig) {
RcppExport SEXP _dqrng_RcppExport_registerCCallable() {
R_RegisterCCallable("dqrng", "_dqrng_dqset_seed", (DL_FUNC)_dqrng_dqset_seed_try);
R_RegisterCCallable("dqrng", "_dqrng_dqRNGkind", (DL_FUNC)_dqrng_dqRNGkind_try);
R_RegisterCCallable("dqrng", "_dqrng_dqrng_get_state", (DL_FUNC)_dqrng_dqrng_get_state_try);
R_RegisterCCallable("dqrng", "_dqrng_dqrng_set_state", (DL_FUNC)_dqrng_dqrng_set_state_try);
R_RegisterCCallable("dqrng", "_dqrng_dqrunif", (DL_FUNC)_dqrng_dqrunif_try);
R_RegisterCCallable("dqrng", "_dqrng_runif", (DL_FUNC)_dqrng_runif_try);
R_RegisterCCallable("dqrng", "_dqrng_dqrnorm", (DL_FUNC)_dqrng_dqrnorm_try);
Expand All @@ -478,6 +546,8 @@ RcppExport SEXP _dqrng_RcppExport_registerCCallable() {
static const R_CallMethodDef CallEntries[] = {
{"_dqrng_dqset_seed", (DL_FUNC) &_dqrng_dqset_seed, 2},
{"_dqrng_dqRNGkind", (DL_FUNC) &_dqrng_dqRNGkind, 2},
{"_dqrng_dqrng_get_state", (DL_FUNC) &_dqrng_dqrng_get_state, 0},
{"_dqrng_dqrng_set_state", (DL_FUNC) &_dqrng_dqrng_set_state, 1},
{"_dqrng_dqrunif", (DL_FUNC) &_dqrng_dqrunif, 3},
{"_dqrng_runif", (DL_FUNC) &_dqrng_runif, 2},
{"_dqrng_dqrnorm", (DL_FUNC) &_dqrng_dqrnorm, 3},
Expand Down
27 changes: 26 additions & 1 deletion src/dqrng.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Copyright 2018-2019 Ralf Stubner (daqana GmbH)
// Copyright 2022-2023 Ralf Stubner
// Copyright 2022-2024 Ralf Stubner
//
// This file is part of dqrng.
//
Expand Down Expand Up @@ -29,6 +29,7 @@

namespace {
dqrng::rng64_t rng = dqrng::generator();
std::string rng_kind = "default";

void init() {
Rcpp::RNGScope rngScope;
Expand Down Expand Up @@ -70,6 +71,7 @@ void dqRNGkind(std::string kind, const std::string& normal_kind = "ignored") {
for (auto & c: kind)
c = std::tolower(c);
uint64_t seed = rng->operator()();
rng_kind = kind;
if (kind == "default") {
rng = dqrng::generator(seed);
} else if (kind == "xoroshiro128+") {
Expand All @@ -89,6 +91,29 @@ void dqRNGkind(std::string kind, const std::string& normal_kind = "ignored") {
}
}

//' @rdname dqrng-functions
//' @export
// [[Rcpp::export(rng = false)]]
std::vector<std::string> dqrng_get_state() {
std::stringstream buffer;
buffer << rng_kind << " " << *rng;
std::vector<std::string> state{std::istream_iterator<std::string>{buffer},
std::istream_iterator<std::string>{}};
return state;
}

//' @rdname dqrng-functions
//' @export
// [[Rcpp::export(rng = false)]]
void dqrng_set_state(std::vector<std::string> state) {
std::stringstream buffer;
std::copy(state.begin() + 1,
state.end(),
std::ostream_iterator<std::string>(buffer, " "));
dqRNGkind(state[0]);
buffer >> *rng;
}

//' @rdname dqrng-functions
//' @export
// [[Rcpp::export(rng = false)]]
Expand Down
9 changes: 9 additions & 0 deletions tests/testthat/test-default.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@ test_that("setting seed produces identical uniformly distributed numbers", {
expect_equal(u1, u2)
})

test_that("saving state produces identical uniformly distributed numbers", {
dqset.seed(seed)
state <- dqrng_get_state()
u1 <- dqrunif(10)
dqrng_set_state(state)
u2 <- dqrunif(10)
expect_identical(u1, u2)
})

test_that("setting seed produces identical uniformly distributed numbers (user defined RNG)", {
register_methods()
dqset.seed(seed)
Expand Down
Loading

0 comments on commit 60e0976

Please sign in to comment.