Skip to content

Commit

Permalink
Changes to RNG handling
Browse files Browse the repository at this point in the history
  • Loading branch information
andrjohns committed May 16, 2024
1 parent 05dd87b commit d72310f
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 15 deletions.
14 changes: 9 additions & 5 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ get_cmdstan_flags <- function(flag_name) {
paste(flags, collapse = " ")
}

rcpp_source_stan <- function(code, env, verbose = FALSE) {
rcpp_source_stan <- function(code, env, verbose = FALSE, ...) {
cxxflags <- get_cmdstan_flags("CXXFLAGS")
cmdstanr_includes <- system.file("include", package = "cmdstanr", mustWork = TRUE)
cmdstanr_includes <- paste0(" -I\"", cmdstanr_includes,"\"")
Expand All @@ -746,7 +746,7 @@ rcpp_source_stan <- function(code, env, verbose = FALSE) {
PKG_CXXFLAGS = paste0(cxxflags, cmdstanr_includes, collapse = " "),
PKG_LIBS = libs
),
Rcpp::sourceCpp(code = code, env = env, verbose = verbose)
Rcpp::sourceCpp(code = code, env = env, verbose = verbose, ...)
)
)
invisible(NULL)
Expand Down Expand Up @@ -887,8 +887,12 @@ prep_fun_cpp <- function(fun_start, fun_end, model_lines) {
}
fun_body <- gsub("// [[stan::function]]", "// [[Rcpp::export]]\n", fun_body, fixed = TRUE)
fun_body <- gsub("std::ostream\\*\\s*pstream__\\s*=\\s*nullptr", "", fun_body)
fun_body <- gsub("boost::ecuyer1988&\\s*base_rng__", "SEXP base_rng_ptr", fun_body)
fun_body <- gsub("base_rng__,", "*(Rcpp::XPtr<boost::ecuyer1988>(base_rng_ptr).get()),", fun_body, fixed = TRUE)
if (cmdstan_version() < "2.35.0") {
fun_body <- gsub("boost::ecuyer1988&\\s*base_rng__", "SEXP base_rng_ptr", fun_body)
} else {
fun_body <- gsub("stan::rng_t&\\s*base_rng__", "SEXP base_rng_ptr", fun_body)
}
fun_body <- gsub("base_rng__,", "*(Rcpp::XPtr<stan::rng_t>(base_rng_ptr).get()),", fun_body, fixed = TRUE)
fun_body <- gsub("pstream__", "&Rcpp::Rcout", fun_body, fixed = TRUE)
fun_body <- paste(fun_body, collapse = "\n")
gsub(pattern = ",\\s*)", replacement = ")", fun_body)
Expand Down Expand Up @@ -935,7 +939,7 @@ compile_functions <- function(env, verbose = FALSE, global = FALSE) {
if (length(rng_funs) > 0) {
rng_cpp <- system.file("include", "base_rng.cpp", package = "cmdstanr", mustWork = TRUE)
rcpp_source_stan(paste0(readLines(rng_cpp), collapse="\n"), env, verbose)
env$rng_ptr <- env$base_rng(seed=0)
env$rng_ptr <- env$base_rng(seed=1)
}

# For all RNG functions, pass the initialised Boost RNG by default
Expand Down
6 changes: 3 additions & 3 deletions inst/include/base_rng.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#include <Rcpp.h>
#include <boost/random/additive_combine.hpp>
#include <stan_rng.hpp>

// [[Rcpp::export]]
SEXP base_rng(boost::uint32_t seed = 0) {
Rcpp::XPtr<boost::ecuyer1988> rng_ptr(new boost::ecuyer1988(seed));
SEXP base_rng(boost::uint32_t seed = 1) {
Rcpp::XPtr<stan::rng_t> rng_ptr(new stan::rng_t(seed));
return rng_ptr;
}
10 changes: 7 additions & 3 deletions tests/testthat/test-model-compile.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ context("model-compile")
set_cmdstan_path()
stan_program <- cmdstan_example_file()
mod <- cmdstan_model(stan_file = stan_program, compile = FALSE)

cmdstan_make_local(cpp_options = list("PRECOMPILED_HEADERS"="false"))

test_that("object initialized correctly", {
expect_equal(mod$stan_file(), stan_program)
Expand Down Expand Up @@ -130,6 +130,9 @@ test_that("name in STANCFLAGS is set correctly", {
test_that("switching threads on and off works without rebuild", {
main_path_o <- file.path(cmdstan_path(), "src", "cmdstan", "main.o")
main_path_threads_o <- file.path(cmdstan_path(), "src", "cmdstan", "main_threads.o")
backup <- cmdstan_make_local()
no_threads <- grep("STAN_THREADS", backup, invert = TRUE, value = TRUE)
cmdstan_make_local(cpp_options = list(no_threads), append = FALSE)
if (file.exists(main_path_threads_o)) {
file.remove(main_path_threads_o)
}
Expand All @@ -155,6 +158,7 @@ test_that("switching threads on and off works without rebuild", {
expect_equal(before_mtime, after_mtime)

expect_warning(mod$compile(threads = TRUE, dry_run = TRUE), "deprecated")
cmdstan_make_local(cpp_options = backup, append = FALSE)
})

test_that("multiple cpp_options work", {
Expand Down Expand Up @@ -483,19 +487,19 @@ test_that("include_paths_stanc3_args() works", {

test_that("cpp_options work with settings in make/local", {
backup <- cmdstan_make_local()
no_threads <- grep("STAN_THREADS", backup, invert = TRUE, value = TRUE)
cmdstan_make_local(cpp_options = list(no_threads), append = FALSE)

if (length(mod$exe_file()) > 0 && file.exists(mod$exe_file())) {
file.remove(mod$exe_file())
}

cmdstan_make_local(cpp_options = list(), append = TRUE)
rebuild_cmdstan()
mod <- cmdstan_model(stan_file = stan_program)
expect_null(mod$cpp_options()$STAN_THREADS)

file.remove(mod$exe_file())

cmdstan_make_local(cpp_options = backup, append = FALSE)
cmdstan_make_local(cpp_options = list(stan_threads = TRUE), append = TRUE)

file <- file.path(cmdstan_path(), "examples", "bernoulli", "bernoulli.stan")
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test-model-expose-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -351,12 +351,12 @@ test_that("rng functions can be exposed", {

expect_equal(
fit$functions$wrap_normal_rng(5,10),
-4.5298764235381225873
0.02974925
)

expect_equal(
fit$functions$wrap_normal_rng(5,10),
8.1295902610102039887
10.3881349
)
})

Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test-profiling.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ test_that("profiling works if profiling data is present", {
profiles <- fit$profiles()
expect_equal(length(profiles), 4)
expect_equal(dim(profiles[[1]]), c(3,9))
expect_equal(profiles[[1]][,"name"], c("glm", "priors", "udf"))
expect_equal(profiles[[1]][,"name"], c("udf", "priors", "glm"))

file.remove(fit$profile_files())
expect_error(
Expand All @@ -24,7 +24,7 @@ test_that("profiling works if profiling data is present", {
profiles_no_csv <- fit$profiles()
expect_equal(length(profiles_no_csv), 4)
expect_equal(dim(profiles_no_csv[[1]]), c(3,9))
expect_equal(profiles_no_csv[[1]][,"name"], c("glm", "priors", "udf"))
expect_equal(profiles_no_csv[[1]][,"name"], c("udf", "priors", "glm"))
})

test_that("profiling errors if no profiling files are present", {
Expand Down

0 comments on commit d72310f

Please sign in to comment.