From 5493eb2cf704a7e0a9326dd660b7ce4a9c2b61d6 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Fri, 17 May 2024 16:43:41 +0300 Subject: [PATCH 1/4] Make exported RNG functions respect changes to R's seed --- R/utils.R | 20 +++++++++++++++----- inst/include/stan_rng.hpp | 10 ++++++++++ tests/testthat/test-model-expose-functions.R | 4 ++-- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/R/utils.R b/R/utils.R index e07e1d57..964dc930 100644 --- a/R/utils.R +++ b/R/utils.R @@ -887,12 +887,16 @@ prep_fun_cpp <- function(fun_start, fun_end, model_lines) { } fun_body <- gsub("// [[stan::function]]", "// [[Rcpp::export]]\n", fun_body, fixed = TRUE) fun_body <- gsub("std::ostream\\*\\s*pstream__\\s*=\\s*nullptr", "", fun_body) - if (cmdstan_version() < "2.35.0") { - fun_body <- gsub("boost::ecuyer1988&\\s*base_rng__", "SEXP base_rng_ptr", fun_body) - } else { - fun_body <- gsub("stan::rng_t&\\s*base_rng__", "SEXP base_rng_ptr", fun_body) + if (grepl("(stan::rng_t|boost::ecuyer1988)", fun_body)) { + if (cmdstan_version() < "2.35.0") { + fun_body <- gsub("boost::ecuyer1988&\\s*base_rng__", "SEXP base_rng_ptr", fun_body) + } else { + fun_body <- gsub("stan::rng_t&\\s*base_rng__", "SEXP base_rng_ptr", fun_body) + } + rng_seed <- "Rcpp::XPtr base_rng(base_rng_ptr);base_rng->seed(get_seed());" + fun_body <- gsub("return", paste(rng_seed, "return"), fun_body) + fun_body <- gsub("base_rng__,", "*(base_rng.get()),", fun_body, fixed = TRUE) } - fun_body <- gsub("base_rng__,", "*(Rcpp::XPtr(base_rng_ptr).get()),", fun_body, fixed = TRUE) fun_body <- gsub("pstream__", "&Rcpp::Rcout", fun_body, fixed = TRUE) fun_body <- paste(fun_body, collapse = "\n") gsub(pattern = ",\\s*)", replacement = ")", fun_body) @@ -999,3 +1003,9 @@ expose_stan_functions <- function(function_env, global = FALSE, verbose = FALSE) } invisible(NULL) } + +# To allow for exported RNG functions to respect the R 'set.seed()' call, +# we need to derive a seed deterministically from the current RNG state +get_seed <- function() { + sample.int(.Machine$integer.max, 1) +} diff --git a/inst/include/stan_rng.hpp b/inst/include/stan_rng.hpp index 834464f7..dd36218f 100644 --- a/inst/include/stan_rng.hpp +++ b/inst/include/stan_rng.hpp @@ -3,6 +3,7 @@ #include #include +#include // A consistent rng_t is defined from 2.35 onwards // so add a fallback for older versions @@ -14,4 +15,13 @@ namespace stan { } #endif +// To ensure that exported RNG functions respect changes to R's RNG state, +// we need to deterministically set the seed of the RNG used by the exported +// functions. +int get_seed() { + Rcpp::Environment pkg = Rcpp::Environment::namespace_env("cmdstanr"); + Rcpp::Function f = pkg["get_seed"]; + return Rcpp::as(f()); +} + #endif diff --git a/tests/testthat/test-model-expose-functions.R b/tests/testthat/test-model-expose-functions.R index 1ee9b78e..50a5c372 100644 --- a/tests/testthat/test-model-expose-functions.R +++ b/tests/testthat/test-model-expose-functions.R @@ -352,12 +352,12 @@ test_that("rng functions can be exposed", { expect_equal( fit$functions$wrap_normal_rng(5,10), # Stan RNG changed in 2.35 - ifelse(cmdstan_version() < "2.35.0",-4.529876423, 0.02974925) + ifelse(cmdstan_version() < "2.35.0", 1.217251562, 20.49842178) ) expect_equal( fit$functions$wrap_normal_rng(5,10), - ifelse(cmdstan_version() < "2.35.0", 8.12959026, 10.3881349) + ifelse(cmdstan_version() < "2.35.0", -0.1426366567, 12.93498553) ) }) From 1d76ca72ea680c0502de16e45060d1561ea7d6a0 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Fri, 17 May 2024 22:18:20 +0300 Subject: [PATCH 2/4] Simpler seed setting --- R/utils.R | 15 ++++++--------- inst/include/stan_rng.hpp | 9 --------- 2 files changed, 6 insertions(+), 18 deletions(-) diff --git a/R/utils.R b/R/utils.R index 964dc930..718b6ddd 100644 --- a/R/utils.R +++ b/R/utils.R @@ -889,11 +889,11 @@ prep_fun_cpp <- function(fun_start, fun_end, model_lines) { fun_body <- gsub("std::ostream\\*\\s*pstream__\\s*=\\s*nullptr", "", fun_body) if (grepl("(stan::rng_t|boost::ecuyer1988)", fun_body)) { if (cmdstan_version() < "2.35.0") { - fun_body <- gsub("boost::ecuyer1988&\\s*base_rng__", "SEXP base_rng_ptr", fun_body) + fun_body <- gsub("boost::ecuyer1988&\\s*base_rng__", "SEXP base_rng_ptr, SEXP seed", fun_body) } else { - fun_body <- gsub("stan::rng_t&\\s*base_rng__", "SEXP base_rng_ptr", fun_body) + fun_body <- gsub("stan::rng_t&\\s*base_rng__", "SEXP base_rng_ptr, SEXP seed", fun_body) } - rng_seed <- "Rcpp::XPtr base_rng(base_rng_ptr);base_rng->seed(get_seed());" + rng_seed <- "Rcpp::XPtr base_rng(base_rng_ptr);base_rng->seed(Rcpp::as(seed));" fun_body <- gsub("return", paste(rng_seed, "return"), fun_body) fun_body <- gsub("base_rng__,", "*(base_rng.get()),", fun_body, fixed = TRUE) } @@ -957,6 +957,9 @@ compile_functions <- function(env, verbose = FALSE, global = FALSE) { fundef <- get(fun, envir = fun_env) funargs <- formals(fundef) funargs$base_rng_ptr <- env$rng_ptr + # To allow for exported RNG functions to respect the R 'set.seed()' call, + # we need to derive a seed deterministically from the current RNG state + funargs$seed <- quote(sample.int(.Machine$integer.max, 1)) formals(fundef) <- funargs assign(fun, fundef, envir = fun_env) } @@ -1003,9 +1006,3 @@ expose_stan_functions <- function(function_env, global = FALSE, verbose = FALSE) } invisible(NULL) } - -# To allow for exported RNG functions to respect the R 'set.seed()' call, -# we need to derive a seed deterministically from the current RNG state -get_seed <- function() { - sample.int(.Machine$integer.max, 1) -} diff --git a/inst/include/stan_rng.hpp b/inst/include/stan_rng.hpp index dd36218f..8dc55dde 100644 --- a/inst/include/stan_rng.hpp +++ b/inst/include/stan_rng.hpp @@ -15,13 +15,4 @@ namespace stan { } #endif -// To ensure that exported RNG functions respect changes to R's RNG state, -// we need to deterministically set the seed of the RNG used by the exported -// functions. -int get_seed() { - Rcpp::Environment pkg = Rcpp::Environment::namespace_env("cmdstanr"); - Rcpp::Function f = pkg["get_seed"]; - return Rcpp::as(f()); -} - #endif From 707c1fb2c1262e21fe5d830a56d584040275e083 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Fri, 17 May 2024 23:55:22 +0300 Subject: [PATCH 3/4] Seed set location --- inst/include/stan_rng.hpp | 1 - tests/testthat/test-model-expose-functions.R | 24 +++++++++++++++++--- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/inst/include/stan_rng.hpp b/inst/include/stan_rng.hpp index 8dc55dde..834464f7 100644 --- a/inst/include/stan_rng.hpp +++ b/inst/include/stan_rng.hpp @@ -3,7 +3,6 @@ #include #include -#include // A consistent rng_t is defined from 2.35 onwards // so add a fallback for older versions diff --git a/tests/testthat/test-model-expose-functions.R b/tests/testthat/test-model-expose-functions.R index 50a5c372..863b7f29 100644 --- a/tests/testthat/test-model-expose-functions.R +++ b/tests/testthat/test-model-expose-functions.R @@ -346,17 +346,35 @@ test_that("rng functions can be exposed", { mod <- cmdstan_model(model, force_recompile = TRUE) fit <- mod$sample(data = data_list) - set.seed(10) fit$expose_functions(verbose = TRUE) + set.seed(10) + res <- fit$functions$wrap_normal_rng(5,10) + + expect_equal( + res, + # Stan RNG changed in 2.35 + ifelse(cmdstan_version() < "2.35.0", 1.217251562, 20.49842178) + ) + res <- fit$functions$wrap_normal_rng(5,10) + + expect_equal( + res, + ifelse(cmdstan_version() < "2.35.0", -0.1426366567, 12.93498553) + ) + + # Test that the RNG function respects set.seed + set.seed(10) + res <- fit$functions$wrap_normal_rng(5,10) expect_equal( - fit$functions$wrap_normal_rng(5,10), + res, # Stan RNG changed in 2.35 ifelse(cmdstan_version() < "2.35.0", 1.217251562, 20.49842178) ) + res <- fit$functions$wrap_normal_rng(5,10) expect_equal( - fit$functions$wrap_normal_rng(5,10), + res, ifelse(cmdstan_version() < "2.35.0", -0.1426366567, 12.93498553) ) }) From 9068bc731d6e9098c9f897ca6c2a5a34d60fff5c Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Sat, 18 May 2024 00:30:04 +0300 Subject: [PATCH 4/4] Fix test --- tests/testthat/test-model-expose-functions.R | 33 ++++---------------- 1 file changed, 6 insertions(+), 27 deletions(-) diff --git a/tests/testthat/test-model-expose-functions.R b/tests/testthat/test-model-expose-functions.R index 863b7f29..516d39c1 100644 --- a/tests/testthat/test-model-expose-functions.R +++ b/tests/testthat/test-model-expose-functions.R @@ -348,35 +348,14 @@ test_that("rng functions can be exposed", { fit$expose_functions(verbose = TRUE) set.seed(10) - res <- fit$functions$wrap_normal_rng(5,10) - - expect_equal( - res, - # Stan RNG changed in 2.35 - ifelse(cmdstan_version() < "2.35.0", 1.217251562, 20.49842178) - ) - res <- fit$functions$wrap_normal_rng(5,10) - - expect_equal( - res, - ifelse(cmdstan_version() < "2.35.0", -0.1426366567, 12.93498553) - ) - - # Test that the RNG function respects set.seed + res1_1 <- fit$functions$wrap_normal_rng(5,10) + res2_1 <- fit$functions$wrap_normal_rng(5,10) set.seed(10) - res <- fit$functions$wrap_normal_rng(5,10) - - expect_equal( - res, - # Stan RNG changed in 2.35 - ifelse(cmdstan_version() < "2.35.0", 1.217251562, 20.49842178) - ) - res <- fit$functions$wrap_normal_rng(5,10) + res1_2 <- fit$functions$wrap_normal_rng(5,10) + res2_2 <- fit$functions$wrap_normal_rng(5,10) - expect_equal( - res, - ifelse(cmdstan_version() < "2.35.0", -0.1426366567, 12.93498553) - ) + expect_equal(res1_1, res1_2) + expect_equal(res2_1, res2_2) }) test_that("Overloaded functions give meaningful errors", {