Skip to content

Commit

Permalink
Make exported RNG functions respect changes to R's seed
Browse files Browse the repository at this point in the history
  • Loading branch information
andrjohns committed May 17, 2024
1 parent c218b0d commit 5493eb2
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 7 deletions.
20 changes: 15 additions & 5 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -887,12 +887,16 @@ prep_fun_cpp <- function(fun_start, fun_end, model_lines) {
}
fun_body <- gsub("// [[stan::function]]", "// [[Rcpp::export]]\n", fun_body, fixed = TRUE)
fun_body <- gsub("std::ostream\\*\\s*pstream__\\s*=\\s*nullptr", "", fun_body)
if (cmdstan_version() < "2.35.0") {
fun_body <- gsub("boost::ecuyer1988&\\s*base_rng__", "SEXP base_rng_ptr", fun_body)
} else {
fun_body <- gsub("stan::rng_t&\\s*base_rng__", "SEXP base_rng_ptr", fun_body)
if (grepl("(stan::rng_t|boost::ecuyer1988)", fun_body)) {
if (cmdstan_version() < "2.35.0") {
fun_body <- gsub("boost::ecuyer1988&\\s*base_rng__", "SEXP base_rng_ptr", fun_body)
} else {
fun_body <- gsub("stan::rng_t&\\s*base_rng__", "SEXP base_rng_ptr", fun_body)
}
rng_seed <- "Rcpp::XPtr<stan::rng_t> base_rng(base_rng_ptr);base_rng->seed(get_seed());"
fun_body <- gsub("return", paste(rng_seed, "return"), fun_body)
fun_body <- gsub("base_rng__,", "*(base_rng.get()),", fun_body, fixed = TRUE)
}
fun_body <- gsub("base_rng__,", "*(Rcpp::XPtr<stan::rng_t>(base_rng_ptr).get()),", fun_body, fixed = TRUE)
fun_body <- gsub("pstream__", "&Rcpp::Rcout", fun_body, fixed = TRUE)
fun_body <- paste(fun_body, collapse = "\n")
gsub(pattern = ",\\s*)", replacement = ")", fun_body)
Expand Down Expand Up @@ -999,3 +1003,9 @@ expose_stan_functions <- function(function_env, global = FALSE, verbose = FALSE)
}
invisible(NULL)
}

# To allow for exported RNG functions to respect the R 'set.seed()' call,
# we need to derive a seed deterministically from the current RNG state
get_seed <- function() {
sample.int(.Machine$integer.max, 1)
}
10 changes: 10 additions & 0 deletions inst/include/stan_rng.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <boost/random/additive_combine.hpp>
#include <stan/version.hpp>
#include <Rcpp.h>

// A consistent rng_t is defined from 2.35 onwards
// so add a fallback for older versions
Expand All @@ -14,4 +15,13 @@ namespace stan {
}
#endif

// To ensure that exported RNG functions respect changes to R's RNG state,
// we need to deterministically set the seed of the RNG used by the exported
// functions.
int get_seed() {
Rcpp::Environment pkg = Rcpp::Environment::namespace_env("cmdstanr");
Rcpp::Function f = pkg["get_seed"];
return Rcpp::as<int>(f());
}

#endif
4 changes: 2 additions & 2 deletions tests/testthat/test-model-expose-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -352,12 +352,12 @@ test_that("rng functions can be exposed", {
expect_equal(
fit$functions$wrap_normal_rng(5,10),
# Stan RNG changed in 2.35
ifelse(cmdstan_version() < "2.35.0",-4.529876423, 0.02974925)
ifelse(cmdstan_version() < "2.35.0", 1.217251562, 20.49842178)
)

expect_equal(
fit$functions$wrap_normal_rng(5,10),
ifelse(cmdstan_version() < "2.35.0", 8.12959026, 10.3881349)
ifelse(cmdstan_version() < "2.35.0", -0.1426366567, 12.93498553)
)
})

Expand Down

0 comments on commit 5493eb2

Please sign in to comment.