diff --git a/R/model.R b/R/model.R index 2a7b4623f..5fe9f70fc 100644 --- a/R/model.R +++ b/R/model.R @@ -230,7 +230,7 @@ CmdStanModel <- R6::R6Class( self$functions <- new.env() self$functions$compiled <- FALSE if (!is.null(stan_file)) { - assert_file_exists(stan_file, access = "r", extension = "stan") + assert_file_exists(stan_file, access = "r", extension = c("stan", "stanfunctions")) checkmate::assert_flag(compile) private$stan_file_ <- absolute_path(stan_file) private$stan_code_ <- readLines(stan_file) @@ -537,7 +537,7 @@ compile <- function(quiet = TRUE, compile_hessian_method <- FALSE } - temp_stan_file <- tempfile(pattern = "model-", fileext = ".stan") + temp_stan_file <- tempfile(pattern = "model-", fileext = paste0(".", tools::file_ext(self$stan_file()))) file.copy(self$stan_file(), temp_stan_file, overwrite = TRUE) temp_file_no_ext <- strip_ext(temp_stan_file) tmp_exe <- cmdstan_ext(temp_file_no_ext) # adds .exe on Windows diff --git a/R/utils.R b/R/utils.R index 96c7739cf..d270cef7f 100644 --- a/R/utils.R +++ b/R/utils.R @@ -748,7 +748,8 @@ expose_model_methods <- function(env, verbose = FALSE, hessian = FALSE) { package = "cmdstanr", mustWork = TRUE))) if (hessian) { - code <- c(code, + code <- c("#include ", + code, readLines(system.file("include", "hessian.cpp", package = "cmdstanr", mustWork = TRUE))) } @@ -758,9 +759,8 @@ expose_model_methods <- function(env, verbose = FALSE, hessian = FALSE) { invisible(NULL) } -initialize_model_pointer <- function(env, data, seed = 0) { - datafile_path <- ifelse(is.null(data), "", data) - ptr_and_rng <- env$model_ptr(datafile_path, seed) +initialize_model_pointer <- function(env, datafile_path, seed = 0) { + ptr_and_rng <- env$model_ptr(ifelse(is.null(datafile_path), "", datafile_path), seed) env$model_ptr_ <- ptr_and_rng$model_ptr env$model_rng_ <- ptr_and_rng$base_rng env$num_upars_ <- env$get_num_upars(env$model_ptr_) @@ -863,8 +863,8 @@ prep_fun_cpp <- function(fun_start, fun_end, model_lines) { fun_body <- gsub("auto", get_plain_rtn(fun_start, fun_end, model_lines), fun_body) fun_body <- gsub("// [[stan::function]]", "// [[Rcpp::export]]\n", fun_body, fixed = TRUE) fun_body <- gsub("std::ostream\\*\\s*pstream__\\s*=\\s*nullptr", "", fun_body) - fun_body <- gsub("boost::ecuyer1988& base_rng__", "size_t seed = 0", fun_body, fixed = TRUE) - fun_body <- gsub("base_rng__,", "*(new boost::ecuyer1988(seed)),", fun_body, fixed = TRUE) + fun_body <- gsub("boost::ecuyer1988&\\s*base_rng__", "SEXP base_rng_ptr", fun_body) + fun_body <- gsub("base_rng__,", "*(Rcpp::XPtr(base_rng_ptr).get()),", fun_body, fixed = TRUE) fun_body <- gsub("pstream__", "&Rcpp::Rcout", fun_body, fixed = TRUE) fun_body <- paste(fun_body, collapse = "\n") gsub(pattern = ",\\s*)", replacement = ")", fun_body) @@ -904,6 +904,30 @@ compile_functions <- function(env, verbose = FALSE, global = FALSE) { } else { rcpp_source_stan(mod_stan_funs, env, verbose) } + + # If an RNG function is exposed, initialise a Boost RNG object stored in the + # environment + rng_funs <- grep("rng\\b", env$fun_names, value = TRUE) + if (length(rng_funs) > 0) { + rng_cpp <- system.file("include", "base_rng.cpp", package = "cmdstanr", mustWork = TRUE) + rcpp_source_stan(paste0(readLines(rng_cpp), collapse="\n"), env, verbose) + env$rng_ptr <- env$base_rng(seed=0) + } + + # For all RNG functions, pass the initialised Boost RNG by default + for (fun in rng_funs) { + if (global) { + fun_env <- globalenv() + } else { + fun_env <- env + } + fundef <- get(fun, envir = fun_env) + funargs <- formals(fundef) + funargs$base_rng_ptr <- env$rng_ptr + formals(fundef) <- funargs + assign(fun, fundef, envir = fun_env) + } + env$compiled <- TRUE invisible(NULL) } diff --git a/inst/include/base_rng.cpp b/inst/include/base_rng.cpp new file mode 100644 index 000000000..7b38ba16f --- /dev/null +++ b/inst/include/base_rng.cpp @@ -0,0 +1,8 @@ +#include +#include + +// [[Rcpp::export]] +SEXP base_rng(boost::uint32_t seed = 0) { + Rcpp::XPtr rng_ptr(new boost::ecuyer1988(seed)); + return rng_ptr; +} diff --git a/tests/testthat/test-model-expose-functions.R b/tests/testthat/test-model-expose-functions.R index 9836d4065..e1e99d1bb 100644 --- a/tests/testthat/test-model-expose-functions.R +++ b/tests/testthat/test-model-expose-functions.R @@ -112,7 +112,7 @@ test_that("Functions can be compiled with model", { test_that("rng functions can be exposed", { skip_if(os_is_wsl()) - function_decl <- "functions { real normal_rng(real mu) { return normal_rng(mu, 1); } }" + function_decl <- "functions { real wrap_normal_rng(real mu, real sigma) { return normal_rng(mu, sigma); } }" stan_prog <- paste(function_decl, paste(readLines(testing_stan_file("bernoulli")), collapse = "\n"), @@ -122,11 +122,17 @@ test_that("rng functions can be exposed", { mod <- cmdstan_model(model, force_recompile = TRUE) fit <- mod$sample(data = data_list) + set.seed(10) fit$expose_functions(verbose = TRUE) expect_equal( - fit$functions$normal_rng(5, seed = 10), - 3.8269637967017344771 + fit$functions$wrap_normal_rng(5,10), + -4.5298764235381225873 + ) + + expect_equal( + fit$functions$wrap_normal_rng(5,10), + 8.1295902610102039887 ) })