From d72310fdd1b11479d1d00b878695ee586467ead9 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Thu, 16 May 2024 20:03:07 +0300 Subject: [PATCH] Changes to RNG handling --- R/utils.R | 14 +++++++++----- inst/include/base_rng.cpp | 6 +++--- tests/testthat/test-model-compile.R | 10 +++++++--- tests/testthat/test-model-expose-functions.R | 4 ++-- tests/testthat/test-profiling.R | 4 ++-- 5 files changed, 23 insertions(+), 15 deletions(-) diff --git a/R/utils.R b/R/utils.R index 8b0bc01a9..1715d58b3 100644 --- a/R/utils.R +++ b/R/utils.R @@ -727,7 +727,7 @@ get_cmdstan_flags <- function(flag_name) { paste(flags, collapse = " ") } -rcpp_source_stan <- function(code, env, verbose = FALSE) { +rcpp_source_stan <- function(code, env, verbose = FALSE, ...) { cxxflags <- get_cmdstan_flags("CXXFLAGS") cmdstanr_includes <- system.file("include", package = "cmdstanr", mustWork = TRUE) cmdstanr_includes <- paste0(" -I\"", cmdstanr_includes,"\"") @@ -746,7 +746,7 @@ rcpp_source_stan <- function(code, env, verbose = FALSE) { PKG_CXXFLAGS = paste0(cxxflags, cmdstanr_includes, collapse = " "), PKG_LIBS = libs ), - Rcpp::sourceCpp(code = code, env = env, verbose = verbose) + Rcpp::sourceCpp(code = code, env = env, verbose = verbose, ...) ) ) invisible(NULL) @@ -887,8 +887,12 @@ prep_fun_cpp <- function(fun_start, fun_end, model_lines) { } fun_body <- gsub("// [[stan::function]]", "// [[Rcpp::export]]\n", fun_body, fixed = TRUE) fun_body <- gsub("std::ostream\\*\\s*pstream__\\s*=\\s*nullptr", "", fun_body) - fun_body <- gsub("boost::ecuyer1988&\\s*base_rng__", "SEXP base_rng_ptr", fun_body) - fun_body <- gsub("base_rng__,", "*(Rcpp::XPtr(base_rng_ptr).get()),", fun_body, fixed = TRUE) + if (cmdstan_version() < "2.35.0") { + fun_body <- gsub("boost::ecuyer1988&\\s*base_rng__", "SEXP base_rng_ptr", fun_body) + } else { + fun_body <- gsub("stan::rng_t&\\s*base_rng__", "SEXP base_rng_ptr", fun_body) + } + fun_body <- gsub("base_rng__,", "*(Rcpp::XPtr(base_rng_ptr).get()),", fun_body, fixed = TRUE) fun_body <- gsub("pstream__", "&Rcpp::Rcout", fun_body, fixed = TRUE) fun_body <- paste(fun_body, collapse = "\n") gsub(pattern = ",\\s*)", replacement = ")", fun_body) @@ -935,7 +939,7 @@ compile_functions <- function(env, verbose = FALSE, global = FALSE) { if (length(rng_funs) > 0) { rng_cpp <- system.file("include", "base_rng.cpp", package = "cmdstanr", mustWork = TRUE) rcpp_source_stan(paste0(readLines(rng_cpp), collapse="\n"), env, verbose) - env$rng_ptr <- env$base_rng(seed=0) + env$rng_ptr <- env$base_rng(seed=1) } # For all RNG functions, pass the initialised Boost RNG by default diff --git a/inst/include/base_rng.cpp b/inst/include/base_rng.cpp index 7b38ba16f..42d9e7721 100644 --- a/inst/include/base_rng.cpp +++ b/inst/include/base_rng.cpp @@ -1,8 +1,8 @@ #include -#include +#include // [[Rcpp::export]] -SEXP base_rng(boost::uint32_t seed = 0) { - Rcpp::XPtr rng_ptr(new boost::ecuyer1988(seed)); +SEXP base_rng(boost::uint32_t seed = 1) { + Rcpp::XPtr rng_ptr(new stan::rng_t(seed)); return rng_ptr; } diff --git a/tests/testthat/test-model-compile.R b/tests/testthat/test-model-compile.R index 23595245f..2be8390f5 100644 --- a/tests/testthat/test-model-compile.R +++ b/tests/testthat/test-model-compile.R @@ -3,7 +3,7 @@ context("model-compile") set_cmdstan_path() stan_program <- cmdstan_example_file() mod <- cmdstan_model(stan_file = stan_program, compile = FALSE) - +cmdstan_make_local(cpp_options = list("PRECOMPILED_HEADERS"="false")) test_that("object initialized correctly", { expect_equal(mod$stan_file(), stan_program) @@ -130,6 +130,9 @@ test_that("name in STANCFLAGS is set correctly", { test_that("switching threads on and off works without rebuild", { main_path_o <- file.path(cmdstan_path(), "src", "cmdstan", "main.o") main_path_threads_o <- file.path(cmdstan_path(), "src", "cmdstan", "main_threads.o") + backup <- cmdstan_make_local() + no_threads <- grep("STAN_THREADS", backup, invert = TRUE, value = TRUE) + cmdstan_make_local(cpp_options = list(no_threads), append = FALSE) if (file.exists(main_path_threads_o)) { file.remove(main_path_threads_o) } @@ -155,6 +158,7 @@ test_that("switching threads on and off works without rebuild", { expect_equal(before_mtime, after_mtime) expect_warning(mod$compile(threads = TRUE, dry_run = TRUE), "deprecated") + cmdstan_make_local(cpp_options = backup, append = FALSE) }) test_that("multiple cpp_options work", { @@ -483,19 +487,19 @@ test_that("include_paths_stanc3_args() works", { test_that("cpp_options work with settings in make/local", { backup <- cmdstan_make_local() + no_threads <- grep("STAN_THREADS", backup, invert = TRUE, value = TRUE) + cmdstan_make_local(cpp_options = list(no_threads), append = FALSE) if (length(mod$exe_file()) > 0 && file.exists(mod$exe_file())) { file.remove(mod$exe_file()) } - cmdstan_make_local(cpp_options = list(), append = TRUE) rebuild_cmdstan() mod <- cmdstan_model(stan_file = stan_program) expect_null(mod$cpp_options()$STAN_THREADS) file.remove(mod$exe_file()) - cmdstan_make_local(cpp_options = backup, append = FALSE) cmdstan_make_local(cpp_options = list(stan_threads = TRUE), append = TRUE) file <- file.path(cmdstan_path(), "examples", "bernoulli", "bernoulli.stan") diff --git a/tests/testthat/test-model-expose-functions.R b/tests/testthat/test-model-expose-functions.R index fb82fe268..1c87b9b06 100644 --- a/tests/testthat/test-model-expose-functions.R +++ b/tests/testthat/test-model-expose-functions.R @@ -351,12 +351,12 @@ test_that("rng functions can be exposed", { expect_equal( fit$functions$wrap_normal_rng(5,10), - -4.5298764235381225873 + 0.02974925 ) expect_equal( fit$functions$wrap_normal_rng(5,10), - 8.1295902610102039887 + 10.3881349 ) }) diff --git a/tests/testthat/test-profiling.R b/tests/testthat/test-profiling.R index 15832df63..218969e7a 100644 --- a/tests/testthat/test-profiling.R +++ b/tests/testthat/test-profiling.R @@ -12,7 +12,7 @@ test_that("profiling works if profiling data is present", { profiles <- fit$profiles() expect_equal(length(profiles), 4) expect_equal(dim(profiles[[1]]), c(3,9)) - expect_equal(profiles[[1]][,"name"], c("glm", "priors", "udf")) + expect_equal(profiles[[1]][,"name"], c("udf", "priors", "glm")) file.remove(fit$profile_files()) expect_error( @@ -24,7 +24,7 @@ test_that("profiling works if profiling data is present", { profiles_no_csv <- fit$profiles() expect_equal(length(profiles_no_csv), 4) expect_equal(dim(profiles_no_csv[[1]]), c(3,9)) - expect_equal(profiles_no_csv[[1]][,"name"], c("glm", "priors", "udf")) + expect_equal(profiles_no_csv[[1]][,"name"], c("udf", "priors", "glm")) }) test_that("profiling errors if no profiling files are present", {