Skip to content

Commit

Permalink
Fix spurious cmdstan config errors (#981)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrjohns authored May 23, 2024
1 parent 91e4bf6 commit 499aa23
Show file tree
Hide file tree
Showing 9 changed files with 35 additions and 77 deletions.
14 changes: 7 additions & 7 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -1167,8 +1167,8 @@ sample <- function(data = NULL,
show_messages = TRUE,
show_exceptions = TRUE,
diagnostics = c("divergences", "treedepth", "ebfmi"),
save_metric = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL },
save_cmdstan_config = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL },
save_metric = NULL,
save_cmdstan_config = NULL,
# deprecated
cores = NULL,
num_cores = NULL,
Expand Down Expand Up @@ -1379,7 +1379,7 @@ sample_mpi <- function(data = NULL,
show_messages = TRUE,
show_exceptions = TRUE,
diagnostics = c("divergences", "treedepth", "ebfmi"),
save_cmdstan_config = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL },
save_cmdstan_config = NULL,
# deprecated
validate_csv = TRUE) {

Expand Down Expand Up @@ -1525,7 +1525,7 @@ optimize <- function(data = NULL,
history_size = NULL,
show_messages = TRUE,
show_exceptions = TRUE,
save_cmdstan_config = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL }) {
save_cmdstan_config = NULL) {
procs <- CmdStanProcs$new(
num_procs = 1,
show_stderr_messages = show_exceptions,
Expand Down Expand Up @@ -1659,7 +1659,7 @@ laplace <- function(data = NULL,
draws = NULL,
show_messages = TRUE,
show_exceptions = TRUE,
save_cmdstan_config = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL }) {
save_cmdstan_config = NULL) {
if (cmdstan_version() < "2.32") {
stop("This method is only available in cmdstan >= 2.32", call. = FALSE)
}
Expand Down Expand Up @@ -1815,7 +1815,7 @@ variational <- function(data = NULL,
draws = NULL,
show_messages = TRUE,
show_exceptions = TRUE,
save_cmdstan_config = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL }) {
save_cmdstan_config = NULL) {
procs <- CmdStanProcs$new(
num_procs = 1,
show_stderr_messages = show_exceptions,
Expand Down Expand Up @@ -1960,7 +1960,7 @@ pathfinder <- function(data = NULL,
calculate_lp = NULL,
show_messages = TRUE,
show_exceptions = TRUE,
save_cmdstan_config = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL }) {
save_cmdstan_config = NULL) {
procs <- CmdStanProcs$new(
num_procs = 1,
show_stderr_messages = show_exceptions,
Expand Down
22 changes: 4 additions & 18 deletions R/run.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ CmdStanRun <- R6::R6Class(
if (cmdstan_version() >= "2.26.0") {
private$profile_files_ <- self$new_profile_files()
}
if (cmdstan_version() >= "2.34.0" && !is.null(self$args$save_cmdstan_config) && self$args$save_cmdstan_config) {
if (cmdstan_version() >= "2.34.0" && !is.null(self$args$save_cmdstan_config) && as.logical(self$args$save_cmdstan_config)) {
private$config_files_ <- self$new_config_files()
}
if (cmdstan_version() >= "2.34.0" && !is.null(self$args$method_args$save_metric) && self$args$method_args$save_metric) {
if (cmdstan_version() >= "2.34.0" && !is.null(self$args$method_args$save_metric) && as.logical(self$args$method_args$save_metric)) {
private$metric_files_ <- self$new_metric_files()
}
if (self$args$save_latent_dynamics) {
Expand Down Expand Up @@ -77,13 +77,6 @@ CmdStanRun <- R6::R6Class(
config_files = function(include_failed = FALSE) {
files <- private$config_files_
files_win_path <- sapply(private$config_files_, wsl_safe_path, revert = TRUE)
if (!length(files) || !any(file.exists(files_win_path))) {
stop(
"No CmdStan config files found. ",
"Set 'save_cmdstan_config=TRUE' when fitting the model.",
call. = FALSE
)
}
if (include_failed) {
files
} else {
Expand All @@ -94,13 +87,6 @@ CmdStanRun <- R6::R6Class(
metric_files = function(include_failed = FALSE) {
files <- private$metric_files_
files_win_path <- sapply(private$metric_files_, wsl_safe_path, revert = TRUE)
if (!length(files) || !any(file.exists(files_win_path))) {
stop(
"No metric files found. ",
"Set 'save_metric=TRUE' when fitting the model.",
call. = FALSE
)
}
if (include_failed) {
files
} else {
Expand Down Expand Up @@ -404,12 +390,12 @@ CmdStanRun <- R6::R6Class(
private$profile_files_,
if (cmdstan_version() > "2.34.0" &&
!is.null(self$args$save_cmdstan_config) &&
self$args$save_cmdstan_config &&
as.logical(self$args$save_cmdstan_config) &&
!private$config_files_saved_)
self$config_files(include_failed = TRUE),
if (cmdstan_version() > "2.34.0" &&
!(is.null(self$args$method_args$save_metric)) &&
self$args$method_args$save_metric &&
as.logical(self$args$method_args$save_metric) &&
!private$metric_files_saved_)
self$metric_files(include_failed = TRUE)
)
Expand Down
7 changes: 1 addition & 6 deletions man/model-method-laplace.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 1 addition & 6 deletions man/model-method-optimize.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 1 addition & 6 deletions man/model-method-pathfinder.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 2 additions & 11 deletions man/model-method-sample.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 1 addition & 6 deletions man/model-method-sample_mpi.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 1 addition & 6 deletions man/model-method-variational.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 17 additions & 11 deletions tests/testthat/test-model-output_dir.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,23 @@ test_that("all fitting methods work with output_dir", {
files <- list.files(method_dir)
}
# specifying output_dir
fit <- testing_fit("bernoulli", method = method, seed = 123,
output_dir = method_dir)
call_args <- list(
"bernoulli",
method = method,
seed = 123,
output_dir = method_dir,
save_cmdstan_config = TRUE
)
if (method == "sample") {
call_args$save_metric <- TRUE
}
fit <- do.call(testing_fit, call_args)
# WSL path manipulations result in a short path which slightly differs
# from the original tempdir(), so need to normalise both for comparison
expect_equal(normalizePath(fit$runset$args$output_dir),
normalizePath(method_dir))
files <- normalizePath(list.files(method_dir, full.names = TRUE))
# in 2.34.0 we also save the config files for all methods and the metric
# for sample
if (cmdstan_version() < "2.34.0") {
mult <- 1
} else if (method == "sample") {
if (method == "sample") {
mult <- 3
expect_equal(files[grepl("metric", files)],
normalizePath(sapply(fit$metric_files(), wsl_safe_path, revert = TRUE,
Expand Down Expand Up @@ -99,7 +104,10 @@ test_that("error if output_dir is invalid", {
})

test_that("output_dir works with trailing /", {
test_dir <- file.path(sandbox, "trailing")
test_dir <- file.path(tempdir(check = TRUE), "output_dir")
if (dir.exists(test_dir)) {
unlink(test_dir, recursive = TRUE)
}
dir.create(test_dir)
fit <- testing_fit(
"bernoulli",
Expand All @@ -109,7 +117,5 @@ test_that("output_dir works with trailing /", {
)
expect_equal(normalizePath(fit$runset$args$output_dir),
normalizePath(test_dir))
# in 2.34.0 we also save the metric and config files
mult <- if (cmdstan_version() >= "2.34.0") 3 else 1
expect_equal(length(list.files(test_dir)), mult * fit$num_procs())
expect_equal(length(list.files(test_dir)), fit$num_procs())
})

0 comments on commit 499aa23

Please sign in to comment.