Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

I/O for RNG state #78

Merged
merged 4 commits into from
Apr 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion LICENSE.note
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ License: AGPL-3

Files: *
Copyright: 2018-2019 Ralf Stubner (daqana GmbH)
2022-2023 Ralf Stubner
2022-2024 Ralf Stubner
License: AGPL-3
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
export(dqRNGkind)
export(dqrexp)
export(dqrmvnorm)
export(dqrng_get_state)
export(dqrng_set_state)
export(dqrnorm)
export(dqrrademacher)
export(dqrunif)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
* Add Xoroshiro128\*\*/++ and Xoshiro256\*\*/++ to `xoshiro.h`
* Allow uniform and normal distributions to be registered as user-supplied RNG within R. This happens automatically if the option `dqrng.register_methods` is set to `TRUE`.
* Add missing inline attributes and limit the included Rcpp headers in `dqrng_types.h` ([#75](https://github.com/daqana/dqrng/pull/75) together with Paul Liétar)
* Add I/O methods for the RNG's internal state (fixing [#66](https://github.com/daqana/dqrng/issues/66) in [#78](https://github.com/daqana/dqrng/pull/78))


# dgrng 0.3.2

Expand Down
12 changes: 12 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,18 @@ dqRNGkind <- function(kind, normal_kind = "ignored") {
invisible(.Call(`_dqrng_dqRNGkind`, kind, normal_kind))
}

#' @rdname dqrng-functions
#' @export
dqrng_get_state <- function() {
.Call(`_dqrng_dqrng_get_state`)
}

#' @rdname dqrng-functions
#' @export
dqrng_set_state <- function(state) {
invisible(.Call(`_dqrng_dqrng_set_state`, state))
}

#' @rdname dqrng-functions
#' @export
dqrunif <- function(n, min = 0.0, max = 1.0) {
Expand Down
5 changes: 1 addition & 4 deletions R/dqrmv.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@
#' plot(x)
dqrmvnorm <- function(n, ...) {
if (!requireNamespace("mvtnorm", quietly = TRUE)) {
stop(
"Package \"mvtnorm\" must be installed to use this function.",
call. = FALSE
)
stop("Package \"mvtnorm\" must be installed to use this function.", call. = FALSE)
}
mvtnorm::rmvnorm(n, ..., rnorm = dqrnorm)
}
15 changes: 14 additions & 1 deletion R/dqset.seed.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
#' generates a random 64 bit integer and then uses each bit to generate
#' a 0/1 variable. This generates 64 integers per random number generation.
#'
#' \code{dqrng_get_state} and \code{dqrng_set_state} can be used to get and set
#' the RNG's internal state. The character vector should not be manipulated directly.
#'
#' @param seed integer scalar to seed the random number generator, or an integer vector of length 2 representing a 64-bit seed. Maybe \code{NULL}, see details.
#' @param stream integer used for selecting the RNG stream; either a scalar or a vector of length 2
#' @param kind string specifying the RNG (see details)
Expand All @@ -22,8 +25,11 @@
#' @param mean mean value of the normal distribution
#' @param sd standard deviation of the normal distribution
#' @param rate rate of the exponential distribution
#' @param state character vector representation of the RNG's internal state
#'
#' @return \code{dqrunif}, \code{dqrnorm}, and \code{dqrexp} return a numeric vector of length \code{n}. \code{dqrrademacher} returns an integer vector of length \code{n}.
#' @return \code{dqrunif}, \code{dqrnorm}, and \code{dqrexp} return a numeric vector
#' of length \code{n}. \code{dqrrademacher} returns an integer vector of length \code{n}.
#' \code{dqrng_get_state} returns a character vector representation of the RNG's internal state.
#'
#' @details Supported RNG kinds:
#' \describe{
Expand Down Expand Up @@ -75,6 +81,13 @@
#' dqrunif(5, min = 2, max = 10)
#' dqrexp(5, rate = 4)
#' dqrnorm(5, mean = 5, sd = 3)
#'
#' # get and restore the state
#' (state <- dqrng_get_state())
#' dqrunif(5)
#' dqrng_set_state(state)
#' dqrunif(5)
#'
#' @rdname dqrng-functions
#' @export
dqset.seed <- function(seed, stream = NULL) {
Expand Down
3 changes: 3 additions & 0 deletions codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ coverage:
default:
target: auto
threshold: 1%
ignore:
- "R/zzz.R"
- "inst/include/pcg_*"
39 changes: 39 additions & 0 deletions inst/include/dqrng_RcppExports.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,45 @@ namespace dqrng {
throw Rcpp::exception(Rcpp::as<std::string>(rcpp_result_gen).c_str());
}

inline std::vector<std::string> dqrng_get_state() {
typedef SEXP(*Ptr_dqrng_get_state)();
static Ptr_dqrng_get_state p_dqrng_get_state = NULL;
if (p_dqrng_get_state == NULL) {
validateSignature("std::vector<std::string>(*dqrng_get_state)()");
p_dqrng_get_state = (Ptr_dqrng_get_state)R_GetCCallable("dqrng", "_dqrng_dqrng_get_state");
}
RObject rcpp_result_gen;
{
rcpp_result_gen = p_dqrng_get_state();
}
if (rcpp_result_gen.inherits("interrupted-error"))
throw Rcpp::internal::InterruptedException();
if (Rcpp::internal::isLongjumpSentinel(rcpp_result_gen))
throw Rcpp::LongjumpException(rcpp_result_gen);
if (rcpp_result_gen.inherits("try-error"))
throw Rcpp::exception(Rcpp::as<std::string>(rcpp_result_gen).c_str());
return Rcpp::as<std::vector<std::string> >(rcpp_result_gen);
}

inline void dqrng_set_state(std::vector<std::string> state) {
typedef SEXP(*Ptr_dqrng_set_state)(SEXP);
static Ptr_dqrng_set_state p_dqrng_set_state = NULL;
if (p_dqrng_set_state == NULL) {
validateSignature("void(*dqrng_set_state)(std::vector<std::string>)");
p_dqrng_set_state = (Ptr_dqrng_set_state)R_GetCCallable("dqrng", "_dqrng_dqrng_set_state");
}
RObject rcpp_result_gen;
{
rcpp_result_gen = p_dqrng_set_state(Shield<SEXP>(Rcpp::wrap(state)));
}
if (rcpp_result_gen.inherits("interrupted-error"))
throw Rcpp::internal::InterruptedException();
if (Rcpp::internal::isLongjumpSentinel(rcpp_result_gen))
throw Rcpp::LongjumpException(rcpp_result_gen);
if (rcpp_result_gen.inherits("try-error"))
throw Rcpp::exception(Rcpp::as<std::string>(rcpp_result_gen).c_str());
}

inline Rcpp::NumericVector dqrunif(size_t n, double min = 0.0, double max = 1.0) {
typedef SEXP(*Ptr_dqrunif)(SEXP,SEXP,SEXP);
static Ptr_dqrunif p_dqrunif = NULL;
Expand Down
22 changes: 21 additions & 1 deletion man/dqrng-functions.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

70 changes: 70 additions & 0 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,70 @@ RcppExport SEXP _dqrng_dqRNGkind(SEXP kindSEXP, SEXP normal_kindSEXP) {
UNPROTECT(1);
return rcpp_result_gen;
}
// dqrng_get_state
std::vector<std::string> dqrng_get_state();
static SEXP _dqrng_dqrng_get_state_try() {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
rcpp_result_gen = Rcpp::wrap(dqrng_get_state());
return rcpp_result_gen;
END_RCPP_RETURN_ERROR
}
RcppExport SEXP _dqrng_dqrng_get_state() {
SEXP rcpp_result_gen;
{
rcpp_result_gen = PROTECT(_dqrng_dqrng_get_state_try());
}
Rboolean rcpp_isInterrupt_gen = Rf_inherits(rcpp_result_gen, "interrupted-error");
if (rcpp_isInterrupt_gen) {
UNPROTECT(1);
Rf_onintr();
}
bool rcpp_isLongjump_gen = Rcpp::internal::isLongjumpSentinel(rcpp_result_gen);
if (rcpp_isLongjump_gen) {
Rcpp::internal::resumeJump(rcpp_result_gen);
}
Rboolean rcpp_isError_gen = Rf_inherits(rcpp_result_gen, "try-error");
if (rcpp_isError_gen) {
SEXP rcpp_msgSEXP_gen = Rf_asChar(rcpp_result_gen);
UNPROTECT(1);
Rf_error("%s", CHAR(rcpp_msgSEXP_gen));
}
UNPROTECT(1);
return rcpp_result_gen;
}
// dqrng_set_state
void dqrng_set_state(std::vector<std::string> state);
static SEXP _dqrng_dqrng_set_state_try(SEXP stateSEXP) {
BEGIN_RCPP
Rcpp::traits::input_parameter< std::vector<std::string> >::type state(stateSEXP);
dqrng_set_state(state);
return R_NilValue;
END_RCPP_RETURN_ERROR
}
RcppExport SEXP _dqrng_dqrng_set_state(SEXP stateSEXP) {
SEXP rcpp_result_gen;
{
rcpp_result_gen = PROTECT(_dqrng_dqrng_set_state_try(stateSEXP));
}
Rboolean rcpp_isInterrupt_gen = Rf_inherits(rcpp_result_gen, "interrupted-error");
if (rcpp_isInterrupt_gen) {
UNPROTECT(1);
Rf_onintr();
}
bool rcpp_isLongjump_gen = Rcpp::internal::isLongjumpSentinel(rcpp_result_gen);
if (rcpp_isLongjump_gen) {
Rcpp::internal::resumeJump(rcpp_result_gen);
}
Rboolean rcpp_isError_gen = Rf_inherits(rcpp_result_gen, "try-error");
if (rcpp_isError_gen) {
SEXP rcpp_msgSEXP_gen = Rf_asChar(rcpp_result_gen);
UNPROTECT(1);
Rf_error("%s", CHAR(rcpp_msgSEXP_gen));
}
UNPROTECT(1);
return rcpp_result_gen;
}
// dqrunif
Rcpp::NumericVector dqrunif(size_t n, double min, double max);
static SEXP _dqrng_dqrunif_try(SEXP nSEXP, SEXP minSEXP, SEXP maxSEXP) {
Expand Down Expand Up @@ -443,6 +507,8 @@ static int _dqrng_RcppExport_validate(const char* sig) {
if (signatures.empty()) {
signatures.insert("void(*dqset_seed)(Rcpp::Nullable<Rcpp::IntegerVector>,Rcpp::Nullable<Rcpp::IntegerVector>)");
signatures.insert("void(*dqRNGkind)(std::string,const std::string&)");
signatures.insert("std::vector<std::string>(*dqrng_get_state)()");
signatures.insert("void(*dqrng_set_state)(std::vector<std::string>)");
signatures.insert("Rcpp::NumericVector(*dqrunif)(size_t,double,double)");
signatures.insert("double(*runif)(double,double)");
signatures.insert("Rcpp::NumericVector(*dqrnorm)(size_t,double,double)");
Expand All @@ -461,6 +527,8 @@ static int _dqrng_RcppExport_validate(const char* sig) {
RcppExport SEXP _dqrng_RcppExport_registerCCallable() {
R_RegisterCCallable("dqrng", "_dqrng_dqset_seed", (DL_FUNC)_dqrng_dqset_seed_try);
R_RegisterCCallable("dqrng", "_dqrng_dqRNGkind", (DL_FUNC)_dqrng_dqRNGkind_try);
R_RegisterCCallable("dqrng", "_dqrng_dqrng_get_state", (DL_FUNC)_dqrng_dqrng_get_state_try);
R_RegisterCCallable("dqrng", "_dqrng_dqrng_set_state", (DL_FUNC)_dqrng_dqrng_set_state_try);
R_RegisterCCallable("dqrng", "_dqrng_dqrunif", (DL_FUNC)_dqrng_dqrunif_try);
R_RegisterCCallable("dqrng", "_dqrng_runif", (DL_FUNC)_dqrng_runif_try);
R_RegisterCCallable("dqrng", "_dqrng_dqrnorm", (DL_FUNC)_dqrng_dqrnorm_try);
Expand All @@ -478,6 +546,8 @@ RcppExport SEXP _dqrng_RcppExport_registerCCallable() {
static const R_CallMethodDef CallEntries[] = {
{"_dqrng_dqset_seed", (DL_FUNC) &_dqrng_dqset_seed, 2},
{"_dqrng_dqRNGkind", (DL_FUNC) &_dqrng_dqRNGkind, 2},
{"_dqrng_dqrng_get_state", (DL_FUNC) &_dqrng_dqrng_get_state, 0},
{"_dqrng_dqrng_set_state", (DL_FUNC) &_dqrng_dqrng_set_state, 1},
{"_dqrng_dqrunif", (DL_FUNC) &_dqrng_dqrunif, 3},
{"_dqrng_runif", (DL_FUNC) &_dqrng_runif, 2},
{"_dqrng_dqrnorm", (DL_FUNC) &_dqrng_dqrnorm, 3},
Expand Down
27 changes: 26 additions & 1 deletion src/dqrng.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Copyright 2018-2019 Ralf Stubner (daqana GmbH)
// Copyright 2022-2023 Ralf Stubner
// Copyright 2022-2024 Ralf Stubner
//
// This file is part of dqrng.
//
Expand Down Expand Up @@ -29,6 +29,7 @@

namespace {
dqrng::rng64_t rng = dqrng::generator();
std::string rng_kind = "default";

void init() {
Rcpp::RNGScope rngScope;
Expand Down Expand Up @@ -70,6 +71,7 @@ void dqRNGkind(std::string kind, const std::string& normal_kind = "ignored") {
for (auto & c: kind)
c = std::tolower(c);
uint64_t seed = rng->operator()();
rng_kind = kind;
if (kind == "default") {
rng = dqrng::generator(seed);
} else if (kind == "xoroshiro128+") {
Expand All @@ -89,6 +91,29 @@ void dqRNGkind(std::string kind, const std::string& normal_kind = "ignored") {
}
}

//' @rdname dqrng-functions
//' @export
// [[Rcpp::export(rng = false)]]
std::vector<std::string> dqrng_get_state() {
std::stringstream buffer;
buffer << rng_kind << " " << *rng;
std::vector<std::string> state{std::istream_iterator<std::string>{buffer},
std::istream_iterator<std::string>{}};
return state;
}

//' @rdname dqrng-functions
//' @export
// [[Rcpp::export(rng = false)]]
void dqrng_set_state(std::vector<std::string> state) {
std::stringstream buffer;
std::copy(state.begin() + 1,
state.end(),
std::ostream_iterator<std::string>(buffer, " "));
dqRNGkind(state[0]);
buffer >> *rng;
}

//' @rdname dqrng-functions
//' @export
// [[Rcpp::export(rng = false)]]
Expand Down
9 changes: 9 additions & 0 deletions tests/testthat/test-default.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@ test_that("setting seed produces identical uniformly distributed numbers", {
expect_equal(u1, u2)
})

test_that("saving state produces identical uniformly distributed numbers", {
dqset.seed(seed)
state <- dqrng_get_state()
u1 <- dqrunif(10)
dqrng_set_state(state)
u2 <- dqrunif(10)
expect_identical(u1, u2)
})

test_that("setting seed produces identical uniformly distributed numbers (user defined RNG)", {
register_methods()
dqset.seed(seed)
Expand Down
Loading