Skip to content

Commit

Permalink
Merge pull request #932 from venpopov/expose_new_stan_args
Browse files Browse the repository at this point in the history
Expose new stan args
  • Loading branch information
andrjohns authored May 4, 2024
2 parents cc2e36d + 6d7ee0e commit 15aa9d9
Show file tree
Hide file tree
Showing 19 changed files with 600 additions and 277 deletions.
35 changes: 30 additions & 5 deletions R/args.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ CmdStanArgs <- R6::R6Class(
sig_figs = NULL,
opencl_ids = NULL,
model_variables = NULL,
num_threads = NULL) {
num_threads = NULL,
save_cmdstan_config = NULL) {

self$model_name <- model_name
self$stan_code <- stan_code
Expand All @@ -60,6 +61,7 @@ CmdStanArgs <- R6::R6Class(
self$save_latent_dynamics <- save_latent_dynamics
self$using_tempdir <- is.null(output_dir)
self$model_variables <- model_variables
self$save_cmdstan_config <- save_cmdstan_config
if (os_is_wsl()) {
# Want to ensure that any files under WSL are written to a tempdir within
# WSL to avoid IO performance issues
Expand Down Expand Up @@ -87,6 +89,9 @@ CmdStanArgs <- R6::R6Class(
self$opencl_ids <- opencl_ids
self$num_threads = NULL
self$method_args$validate(num_procs = length(self$proc_ids))
if (is.logical(self$save_cmdstan_config)) {
self$save_cmdstan_config <- as.integer(self$save_cmdstan_config)
}
self$validate()
},
validate = function() {
Expand All @@ -111,7 +116,7 @@ CmdStanArgs <- R6::R6Class(
} else if (type == "profile") {
basename <- paste0(basename, "-profile")
}
if (type == "output" && !is.null(self$output_basename)) {
if (type == "output" && !is.null(self$output_basename)) {
basename <- self$output_basename
}
generate_file_names(
Expand Down Expand Up @@ -180,6 +185,9 @@ CmdStanArgs <- R6::R6Class(
if (!is.null(profile_file)) {
args$output <- c(args$output, paste0("profile_file=", wsl_safe_path(profile_file)))
}
if (!is.null(self$save_cmdstan_config)) {
args$output <- c(args$output, paste0("save_cmdstan_config=", self$save_cmdstan_config))
}
if (!is.null(self$opencl_ids)) {
args$opencl <- c("opencl", paste0("platform=", self$opencl_ids[1]), paste0("device=", self$opencl_ids[2]))
}
Expand Down Expand Up @@ -218,7 +226,8 @@ SampleArgs <- R6::R6Class(
term_buffer = NULL,
window = NULL,
fixed_param = FALSE,
diagnostics = NULL) {
diagnostics = NULL,
save_metric = NULL) {

self$iter_warmup <- iter_warmup
self$iter_sampling <- iter_sampling
Expand All @@ -232,6 +241,7 @@ SampleArgs <- R6::R6Class(
self$inv_metric <- inv_metric
self$fixed_param <- fixed_param
self$diagnostics <- diagnostics
self$save_metric <- save_metric
if (identical(self$diagnostics, "")) {
self$diagnostics <- NULL
}
Expand Down Expand Up @@ -275,6 +285,9 @@ SampleArgs <- R6::R6Class(
if (is.logical(self$save_warmup)) {
self$save_warmup <- as.integer(self$save_warmup)
}
if (is.logical(self$save_metric)) {
self$save_metric <- as.integer(self$save_metric)
}
invisible(self)
},
validate = function(num_procs) {
Expand Down Expand Up @@ -314,7 +327,8 @@ SampleArgs <- R6::R6Class(
.make_arg("adapt_engaged"),
.make_arg("init_buffer"),
.make_arg("term_buffer"),
.make_arg("window")
.make_arg("window"),
.make_arg("save_metric")
)
} else {
new_args <- list(
Expand All @@ -335,7 +349,8 @@ SampleArgs <- R6::R6Class(
.make_arg("adapt_engaged"),
.make_arg("init_buffer"),
.make_arg("term_buffer"),
.make_arg("window")
.make_arg("window"),
.make_arg("save_metric")
)
}
new_args <- do.call(c, new_args)
Expand Down Expand Up @@ -682,6 +697,7 @@ validate_cmdstan_args <- function(self) {
checkmate::assert_flag(self$save_latent_dynamics)
checkmate::assert_integerish(self$refresh, lower = 0, null.ok = TRUE)
checkmate::assert_integerish(self$sig_figs, lower = 1, upper = 18, null.ok = TRUE)
checkmate::assert_integerish(self$save_cmdstan_config, lower = 0, upper = 1, len = 1, null.ok = TRUE)
if (!is.null(self$sig_figs) && cmdstan_version() < "2.25") {
warning("The 'sig_figs' argument is only supported with cmdstan 2.25+ and will be ignored!", call. = FALSE)
}
Expand Down Expand Up @@ -799,6 +815,15 @@ validate_sample_args <- function(self, num_procs) {
checkmate::assert_subset(self$diagnostics, empty.ok = FALSE, choices = available_hmc_diagnostics())
}

checkmate::assert_integerish(self$save_metric,
lower = 0, upper = 1,
len = 1,
null.ok = TRUE)

if (is.null(self$adapt_engaged) || (!self$adapt_engaged && !is.null(self$save_metric))) {
self$save_metric <- 0
}

invisible(TRUE)
}

Expand Down
50 changes: 46 additions & 4 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -898,10 +898,13 @@ CmdStanFit$set("public", name = "cmdstan_diagnose", value = cmdstan_diagnose)
#' Save output and data files
#'
#' @name fit-method-save_output_files
#' @aliases fit-method-save_data_file fit-method-save_latent_dynamics_files fit-method-save_profile_files
#' fit-method-output_files fit-method-data_file fit-method-latent_dynamics_files fit-method-profile_files
#' save_output_files save_data_file save_latent_dynamics_files save_profile_files
#' output_files data_file latent_dynamics_files profile_files
#' @aliases fit-method-save_data_file fit-method-save_latent_dynamics_files
#' fit-method-save_profile_files fit-method-output_files fit-method-data_file
#' fit-method-latent_dynamics_files fit-method-profile_files
#' fit-method-save_config_files fit-method-save_metric_files save_output_files
#' save_data_file save_latent_dynamics_files save_profile_files
#' save_config_files save_metric_files output_files data_file
#' latent_dynamics_files profile_files config_files metric_files
#'
#' @description All fitted model objects have methods for saving (moving to a
#' specified location) the files created by CmdStanR to hold CmdStan output
Expand Down Expand Up @@ -936,6 +939,14 @@ CmdStanFit$set("public", name = "cmdstan_diagnose", value = cmdstan_diagnose)
#' `$save_output_files()` except `"-profile-"` is included in the new
#' file name after `basename`.
#'
#' For `$save_metric_files()` everything is the same as for
#' `$save_output_files()` except `"-metric-"` is included in the new
#' file name after `basename`.
#'
#' For `$save_config_files()` everything is the same as for
#' `$save_output_files()` except `"-config-"` is included in the new
#' file name after `basename`.
#'
#' For `$save_data_file()` no `id` is included in the file name because even
#' with multiple MCMC chains the data file is the same.
#'
Expand Down Expand Up @@ -998,6 +1009,26 @@ save_data_file <- function(dir = ".",
}
CmdStanFit$set("public", name = "save_data_file", value = save_data_file)

#' @rdname fit-method-save_output_files
save_config_files <- function(dir = ".",
basename = NULL,
timestamp = TRUE,
random = TRUE) {
self$runset$save_config_files(dir, basename, timestamp, random)
}
CmdStanFit$set("public", name = "save_config_files", value = save_config_files)

#' @rdname fit-method-save_output_files
save_metric_files <- function(dir = ".",
basename = NULL,
timestamp = TRUE,
random = TRUE) {
self$runset$save_metric_files(dir, basename, timestamp, random)
}
CmdStanFit$set("public", name = "save_metric_files", value = save_metric_files)



#' @rdname fit-method-save_output_files
#' @param include_failed (logical) Should CmdStan runs that failed also be
#' included? The default is `FALSE.`
Expand All @@ -1024,6 +1055,17 @@ data_file <- function() {
}
CmdStanFit$set("public", name = "data_file", value = data_file)

#' @rdname fit-method-save_output_files
config_files <- function(include_failed = FALSE) {
self$runset$config_files(include_failed)
}
CmdStanFit$set("public", name = "config_files", value = config_files)

#' @rdname fit-method-save_output_files
metric_files <- function(include_failed = FALSE) {
self$runset$metric_files(include_failed)
}
CmdStanFit$set("public", name = "metric_files", value = metric_files)

#' Report timing of CmdStan runs
#'
Expand Down
36 changes: 25 additions & 11 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -1149,6 +1149,8 @@ sample <- function(data = NULL,
show_messages = TRUE,
show_exceptions = TRUE,
diagnostics = c("divergences", "treedepth", "ebfmi"),
save_metric = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL },
save_cmdstan_config = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL },
# deprecated
cores = NULL,
num_cores = NULL,
Expand Down Expand Up @@ -1240,7 +1242,8 @@ sample <- function(data = NULL,
term_buffer = term_buffer,
window = window,
fixed_param = fixed_param,
diagnostics = diagnostics
diagnostics = diagnostics,
save_metric = save_metric
)
args <- CmdStanArgs$new(
method_args = sample_args,
Expand All @@ -1260,7 +1263,8 @@ sample <- function(data = NULL,
output_basename = output_basename,
sig_figs = sig_figs,
opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()),
model_variables = model_variables
model_variables = model_variables,
save_cmdstan_config = save_cmdstan_config
)
runset <- CmdStanRun$new(args, procs)
runset$run_cmdstan()
Expand Down Expand Up @@ -1357,6 +1361,7 @@ sample_mpi <- function(data = NULL,
show_messages = TRUE,
show_exceptions = TRUE,
diagnostics = c("divergences", "treedepth", "ebfmi"),
save_cmdstan_config = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL },
# deprecated
validate_csv = TRUE) {

Expand Down Expand Up @@ -1420,7 +1425,8 @@ sample_mpi <- function(data = NULL,
output_dir = output_dir,
output_basename = output_basename,
sig_figs = sig_figs,
model_variables = model_variables
model_variables = model_variables,
save_cmdstan_config = save_cmdstan_config
)
runset <- CmdStanRun$new(args, procs)
runset$run_cmdstan_mpi(mpi_cmd, mpi_args)
Expand Down Expand Up @@ -1500,7 +1506,8 @@ optimize <- function(data = NULL,
tol_param = NULL,
history_size = NULL,
show_messages = TRUE,
show_exceptions = TRUE) {
show_exceptions = TRUE,
save_cmdstan_config = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL }) {
procs <- CmdStanProcs$new(
num_procs = 1,
show_stderr_messages = show_exceptions,
Expand Down Expand Up @@ -1541,7 +1548,8 @@ optimize <- function(data = NULL,
output_basename = output_basename,
sig_figs = sig_figs,
opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()),
model_variables = model_variables
model_variables = model_variables,
save_cmdstan_config = save_cmdstan_config
)
runset <- CmdStanRun$new(args, procs)
runset$run_cmdstan()
Expand Down Expand Up @@ -1632,7 +1640,8 @@ laplace <- function(data = NULL,
jacobian = TRUE, # different default than for optimize!
draws = NULL,
show_messages = TRUE,
show_exceptions = TRUE) {
show_exceptions = TRUE,
save_cmdstan_config = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL }) {
if (cmdstan_version() < "2.32") {
stop("This method is only available in cmdstan >= 2.32", call. = FALSE)
}
Expand Down Expand Up @@ -1706,7 +1715,8 @@ laplace <- function(data = NULL,
output_basename = output_basename,
sig_figs = sig_figs,
opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()),
model_variables = model_variables
model_variables = model_variables,
save_cmdstan_config = save_cmdstan_config
)
runset <- CmdStanRun$new(args, procs)
runset$run_cmdstan()
Expand Down Expand Up @@ -1786,7 +1796,8 @@ variational <- function(data = NULL,
output_samples = NULL,
draws = NULL,
show_messages = TRUE,
show_exceptions = TRUE) {
show_exceptions = TRUE,
save_cmdstan_config = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL }) {
procs <- CmdStanProcs$new(
num_procs = 1,
show_stderr_messages = show_exceptions,
Expand Down Expand Up @@ -1827,7 +1838,8 @@ variational <- function(data = NULL,
output_basename = output_basename,
sig_figs = sig_figs,
opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()),
model_variables = model_variables
model_variables = model_variables,
save_cmdstan_config = save_cmdstan_config
)
runset <- CmdStanRun$new(args, procs)
runset$run_cmdstan()
Expand Down Expand Up @@ -1929,7 +1941,8 @@ pathfinder <- function(data = NULL,
psis_resample = NULL,
calculate_lp = NULL,
show_messages = TRUE,
show_exceptions = TRUE) {
show_exceptions = TRUE,
save_cmdstan_config = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL }) {
procs <- CmdStanProcs$new(
num_procs = 1,
show_stderr_messages = show_exceptions,
Expand Down Expand Up @@ -1976,7 +1989,8 @@ pathfinder <- function(data = NULL,
sig_figs = sig_figs,
opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()),
model_variables = model_variables,
num_threads = num_threads
num_threads = num_threads,
save_cmdstan_config = save_cmdstan_config
)
runset <- CmdStanRun$new(args, procs)
runset$run_cmdstan()
Expand Down
Loading

0 comments on commit 15aa9d9

Please sign in to comment.