Skip to content

Commit

Permalink
adds feature to make inits from fit and draws objects
Browse files Browse the repository at this point in the history
  • Loading branch information
SteveBronder committed Mar 21, 2024
1 parent ae1b7b3 commit ae46447
Show file tree
Hide file tree
Showing 17 changed files with 318 additions and 28 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ URL: https://mc-stan.org/cmdstanr/, https://discourse.mc-stan.org
BugReports: https://github.com/stan-dev/cmdstanr/issues
Encoding: UTF-8
LazyData: true
RoxygenNote: 7.3.0
RoxygenNote: 7.3.1
Roxygen: list(markdown = TRUE, r6 = FALSE)
SystemRequirements: CmdStan (https://mc-stan.org/users/interfaces/cmdstan)
Depends:
Expand Down
230 changes: 218 additions & 12 deletions R/args.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,12 @@ CmdStanArgs <- R6::R6Class(
}
self$output_dir <- repair_path(self$output_dir)
self$output_basename <- output_basename
if (is.function(init)) {
init <- process_init_function(init, length(self$proc_ids), model_variables)
} else if (is.list(init) && !is.data.frame(init)) {
init <- process_init_list(init, length(self$proc_ids), model_variables)
if (inherits(self$method_args, "PathfinderArgs")) {
num_inits <- self$method_args$num_paths
} else {
num_inits <- length(self$proc_ids)
}
init <- process_init(init, num_inits, model_variables)
self$init <- init
self$opencl_ids <- opencl_ids
self$num_threads = NULL
Expand Down Expand Up @@ -691,7 +692,12 @@ validate_cmdstan_args <- function(self) {
assert_file_exists(self$data_file, access = "r")
}
num_procs <- length(self$proc_ids)
validate_init(self$init, num_procs)
if (inherits(self$method_args, "PathfinderArgs")) {
num_inits <- self$method_args$num_paths
} else {
num_inits <- length(self$proc_ids)
}
validate_init(self$init, num_inits)
validate_seed(self$seed, num_procs)
if (!is.null(self$opencl_ids)) {
if (cmdstan_version() < "2.26") {
Expand Down Expand Up @@ -1018,17 +1024,63 @@ validate_exe_file <- function(exe_file) {
invisible(TRUE)
}


#' S3 generic for processing inits
#'
#' Dispatches on the class of the first argument supplied; all arguments are
#' made available to the selected method.
#' @noRd
#' @param ... The `init` object followed by any method-specific arguments.
process_init <- function(...) UseMethod("process_init")

#' Fallback `process_init` method
#'
#' Used when no class-specific method applies; the value is handed back
#' unchanged.
#' @noRd
#' @param x The `init` value.
#' @param ... Ignored.
#' @return `x`, unmodified.
process_init.default <- function(x, ...) {
  x
}

#' Write initial values to files if provided as posterior `draws` object
#'
#' Keeps only the requested variables, resamples `num_procs` draws without
#' replacement, converts each selected draw into a named list of values, and
#' hands the resulting list of lists back to `process_init()`.
#' @noRd
#' @param init A type that inherits the `posterior::draws` class.
#' @param num_procs Number of inits requested
#' @param model_variables A list of all parameters with their types and
#'   number of dimensions. Typically the output of model$variables(). If
#'   `NULL`, every variable in `init` whose name does not contain `__` is used.
#' @param warn_partial Should a warning be thrown if inits are only specified
#'   for a subset of parameters? Can be controlled by global option
#'   `cmdstanr_warn_inits`.
#' @return The value of `process_init()` applied to the generated list of
#'   lists (documented there as a character vector of file paths).
process_init.draws <- function(init, num_procs, model_variables = NULL,
                               warn_partial = getOption("cmdstanr_warn_inits", TRUE)) {
  if (!is.null(model_variables)) {
    variable_names <- names(model_variables$parameters)
  } else {
    # No model information: take the names from `init` itself. The rvars
    # format is used so that the names are the ones later looked up with
    # `draws_rvar[[var_name]]`. Names containing "__" are skipped.
    available <- posterior::variables(posterior::as_draws_rvars(init))
    variable_names <- available[!grepl("__", available)]
  }
  draws <- posterior::subset_draws(init, variable = variable_names)
  draws <- posterior::resample_draws(draws, ndraws = num_procs,
                                     method = "simple_no_replace")
  draws_rvar <- posterior::as_draws_rvars(draws)
  inits <- lapply(seq_len(num_procs), function(draw_iter) {
    init_i <- lapply(variable_names, function(var_name) {
      # Pull out a single draw of this variable and strip singleton dimensions.
      # NOTE(review): `drop()` removes every length-one dimension, not only
      # the draw dimension -- confirm this is acceptable for parameters
      # declared with size-one dimensions.
      drop(posterior::draws_of(drop(
        posterior::subset_draws(draws_rvar[[var_name]], draw = draw_iter))))
    })
    names(init_i) <- variable_names
    init_i
  })
  process_init(inits, num_procs, model_variables, warn_partial)
}

#' Write initial values to files if provided as list of lists
#' @noRd
#' @param init List of init lists.
#' @param num_procs Number of CmdStan processes.
#' @param num_procs Number of inits needed.
#' @param model_variables A list of all parameters with their types and
#' number of dimensions. Typically the output of model$variables().
#' @param warn_partial Should a warning be thrown if inits are only specified
#' for a subset of parameters? Can be controlled by global option
#' `cmdstanr_warn_inits`.
#' @return A character vector of file paths.
process_init_list <- function(init, num_procs, model_variables = NULL,
process_init.list <- function(init, num_procs, model_variables = NULL,
warn_partial = getOption("cmdstanr_warn_inits", TRUE)) {
if (!all(sapply(init, function(x) is.list(x) && !is.data.frame(x)))) {
stop("If 'init' is a list it must be a list of lists.", call. = FALSE)
Expand Down Expand Up @@ -1083,10 +1135,11 @@ process_init_list <- function(init, num_procs, model_variables = NULL,
}
init_paths <-
tempfile(
pattern = paste0("init-", seq_along(init), "-"),
pattern = "init-",
tmpdir = cmdstan_tempdir(),
fileext = ".json"
fileext = ""
)
init_paths <- paste0(init_paths, "_", seq_along(init), ".json")
for (i in seq_along(init)) {
write_stan_json(init[[i]], init_paths[i])
}
Expand All @@ -1096,11 +1149,12 @@ process_init_list <- function(init, num_procs, model_variables = NULL,
#' Write initial values to files if provided as function
#' @noRd
#' @param init Function generating a single list of initial values.
#' @param num_procs Number of CmdStan processes.
#' @param num_procs Number of inits needed.
#' @param model_variables A list of all parameters with their types and
#' number of dimensions. Typically the output of model$variables().
#' @return A character vector of file paths.
process_init_function <- function(init, num_procs, model_variables = NULL) {
process_init.function <- function(init, num_procs, model_variables = NULL,
warn_partial = getOption("cmdstanr_warn_inits", TRUE)) {
args <- formals(init)
if (is.null(args)) {
fn_test <- init()
Expand All @@ -1116,7 +1170,159 @@ process_init_function <- function(init, num_procs, model_variables = NULL) {
if (!is.list(fn_test) || is.data.frame(fn_test)) {
stop("If 'init' is a function it must return a single list.")
}
process_init_list(init_list, num_procs, model_variables)
process_init(init_list, num_procs, model_variables)
}

#' Write initial values to files if provided as a `CmdStanMCMC` class
#'
#' Extracts the draws of a previous fit as a data frame, resamples `num_procs`
#' of them without replacement, and passes them on to `process_init()`.
#' @noRd
#' @param init A `CmdStanMCMC` class
#' @param num_procs Number of inits requested
#' @param model_variables A list of all parameters with their types and
#'   number of dimensions. Typically the output of model$variables(). May be
#'   `NULL`, in which case the parameter-name check is skipped.
#' @param warn_partial Should a warning be thrown if inits are only specified
#'   for a subset of parameters? Can be controlled by global option
#'   `cmdstanr_warn_inits`.
#' @return The value of `process_init()` applied to the resampled draws
#'   (documented there as a character vector of file paths).
process_init.CmdStanMCMC <- function(init, num_procs, model_variables = NULL,
                                     warn_partial = getOption("cmdstanr_warn_inits", TRUE)) {
  if (all(init$return_codes() == 1)) {
    stop("We are unable to create initial values from a model with no samples. Please check the results of the model used for inits before continuing.")
  } else if (!is.null(model_variables) &&
             !any(names(model_variables$parameters) %in% init$metadata()$stan_variables)) {
    # Only meaningful when model information was supplied; without the NULL
    # guard this branch would fire for every call with `model_variables = NULL`.
    stop("None of the names of the parameters for the model used for initial values match the names of parameters from the model currently running.")
  }
  draws_df <- init$draws(format = "df")
  if (is.null(model_variables)) {
    # Fallback: use every column except the first and the last three.
    # Presumably these are `lp__` and the `.chain`/`.iteration`/`.draw`
    # bookkeeping columns -- TODO confirm.
    # NOTE(review): this is an unnamed character vector, while the check above
    # reads `names(model_variables$parameters)`; confirm downstream methods
    # accept this shape.
    model_variables <- list(parameters = colnames(draws_df)[2:(length(colnames(draws_df)) - 3)])
  }
  init_draws_df <- posterior::resample_draws(draws_df, ndraws = num_procs,
                                             method = "simple_no_replace")
  process_init(init_draws_df, num_procs = num_procs,
               model_variables = model_variables, warn_partial = warn_partial)
}

#' Performs PSIS resampling on the draws from an approximation method for inits.
#'
#' Computes a log weight `lp__ - lp_approx__` for every draw, checks that there
#' are at least `num_procs` distinct weights, turns the weights into resampling
#' weights, resamples `num_procs` draws without replacement, and passes them on
#' to `process_init()`.
#' @noRd
#' @param init A fit object whose `$draws(format = "df")` output has `lp__` and
#'   `lp_approx__` columns.
#' @param num_procs Number of inits requested
#' @param model_variables A list of all parameters with their types and
#'   number of dimensions. Typically the output of model$variables(). May be
#'   `NULL`, in which case the parameter-name check is skipped.
#' @param warn_partial Should a warning be thrown if inits are only specified
#'   for a subset of parameters? Can be controlled by global option
#'   `cmdstanr_warn_inits`.
#' @return The value of `process_init()` applied to the resampled draws
#'   (documented there as a character vector of file paths).
process_init_approx <- function(init, num_procs, model_variables = NULL,
                                warn_partial = getOption("cmdstanr_warn_inits", TRUE)) {
  # `all()` keeps the condition scalar even if several return codes come back.
  if (all(init$return_codes() == 1)) {
    stop("We are unable to create initial values from a model with no samples. Please check the results of the model used for inits before continuing.")
  } else if (!is.null(model_variables) &&
             !any(names(model_variables$parameters) %in% init$metadata()$stan_variables)) {
    # Only meaningful when model information was supplied; without the NULL
    # guard this branch would fire for every call with `model_variables = NULL`.
    stop("None of the names of the parameters for the model used for initial values match the names of parameters from the model currently running.")
  }
  draws_df <- init$draws(format = "df")
  if (is.null(model_variables)) {
    # Fallback: use every column except the first two and the last three.
    # Presumably these are the `lp__`/`lp_approx__` columns and the
    # `.chain`/`.iteration`/`.draw` bookkeeping columns -- TODO confirm.
    # NOTE(review): this is an unnamed character vector, while the check above
    # reads `names(model_variables$parameters)`; confirm downstream methods
    # accept this shape.
    model_variables <- list(parameters = colnames(draws_df)[3:(length(colnames(draws_df)) - 3)])
  }
  draws_df$lw <- draws_df$lp__ - draws_df$lp_approx__
  # Number of distinct draws, judged by their log weight.
  unique_draws <- length(unique(draws_df$lw))
  if (num_procs > unique_draws) {
    # NOTE(review): the number in parentheses is the number of inits
    # requested, not the number of distinct draws found.
    if (inherits(init, "CmdStanPathfinder")) {
      stop(paste0("Not enough distinct draws (", num_procs, ") in pathfinder fit to create inits. Try running Pathfinder with psis_resample=FALSE"))
    } else {
      stop(paste0("Not enough distinct draws (", num_procs, ") to create inits."))
    }
  }
  if (unique_draws < (0.95 * nrow(draws_df))) {
    # Many repeated weights: use the raw, max-normalised weights.
    # NOTE(review): the merge is keyed on `lw` only, so every row sharing a
    # weight is retained -- confirm this deduplicates as intended.
    temp_df <- stats::aggregate(.draw ~ lw, data = draws_df, FUN = min)
    draws_df <- posterior::as_draws_df(merge(temp_df, draws_df, by = 'lw'))
    draws_df$pareto_weight <- exp(draws_df$lw - max(draws_df$lw))
  } else {
    # Mostly distinct weights: Pareto-smooth the right tail of the weights.
    draws_df$pareto_weight <- posterior::pareto_smooth(
      exp(draws_df$lw - max(draws_df$lw)), tail = "right")[["x"]]
  }
  init_draws_df <- posterior::resample_draws(draws_df, ndraws = num_procs,
                                             weights = draws_df$pareto_weight,
                                             method = "simple_no_replace")
  process_init(init_draws_df, num_procs = num_procs,
               model_variables = model_variables, warn_partial = warn_partial)
}


#' `process_init` method for `CmdStanPathfinder` fits
#'
#' Delegates all of its arguments to `process_init_approx()`.
#' @noRd
#' @param init A `CmdStanPathfinder` class
#' @param num_procs Number of inits requested
#' @param model_variables A list of all parameters with their types and
#'   number of dimensions. Typically the output of model$variables().
#' @param warn_partial Should a warning be thrown if inits are only specified
#'   for a subset of parameters? Can be controlled by global option
#'   `cmdstanr_warn_inits`.
#' @return Whatever `process_init_approx()` returns.
process_init.CmdStanPathfinder <- function(init, num_procs, model_variables = NULL,
                                           warn_partial = getOption("cmdstanr_warn_inits", TRUE)) {
  process_init_approx(
    init = init,
    num_procs = num_procs,
    model_variables = model_variables,
    warn_partial = warn_partial
  )
}

#' `process_init` method for `CmdStanVB` fits
#'
#' Delegates all of its arguments to `process_init_approx()`.
#' @noRd
#' @param init A `CmdStanVB` class
#' @param num_procs Number of inits requested
#' @param model_variables A list of all parameters with their types and
#'   number of dimensions. Typically the output of model$variables().
#' @param warn_partial Should a warning be thrown if inits are only specified
#'   for a subset of parameters? Can be controlled by global option
#'   `cmdstanr_warn_inits`.
#' @return Whatever `process_init_approx()` returns.
process_init.CmdStanVB <- function(init, num_procs, model_variables = NULL,
                                   warn_partial = getOption("cmdstanr_warn_inits", TRUE)) {
  process_init_approx(
    init = init,
    num_procs = num_procs,
    model_variables = model_variables,
    warn_partial = warn_partial
  )
}

#' `process_init` method for `CmdStanLaplace` fits
#'
#' Delegates all of its arguments to `process_init_approx()`.
#' @noRd
#' @param init A `CmdStanLaplace` class
#' @param num_procs Number of inits requested
#' @param model_variables A list of all parameters with their types and
#'   number of dimensions. Typically the output of model$variables().
#' @param warn_partial Should a warning be thrown if inits are only specified
#'   for a subset of parameters? Can be controlled by global option
#'   `cmdstanr_warn_inits`.
#' @return Whatever `process_init_approx()` returns.
process_init.CmdStanLaplace <- function(init, num_procs, model_variables = NULL,
                                        warn_partial = getOption("cmdstanr_warn_inits", TRUE)) {
  process_init_approx(
    init = init,
    num_procs = num_procs,
    model_variables = model_variables,
    warn_partial = warn_partial
  )
}


#' Write initial values to files if provided as a `CmdStanMLE` class
#'
#' Extracts the draws of a previous optimization fit as a data frame,
#' resamples `num_procs` of them with the `"simple"` method, and passes them on
#' to `process_init()`.
#' @noRd
#' @param init A `CmdStanMLE` class
#' @param num_procs Number of inits requested
#' @param model_variables A list of all parameters with their types and
#'   number of dimensions. Typically the output of model$variables(). May be
#'   `NULL`, in which case the parameter-name check is skipped.
#' @param warn_partial Should a warning be thrown if inits are only specified
#'   for a subset of parameters? Can be controlled by global option
#'   `cmdstanr_warn_inits`.
#' @return The value of `process_init()` applied to the resampled draws
#'   (documented there as a character vector of file paths).
process_init.CmdStanMLE <- function(init, num_procs, model_variables = NULL,
                                    warn_partial = getOption("cmdstanr_warn_inits", TRUE)) {
  # `all()` keeps the condition scalar even if several return codes come back.
  if (all(init$return_codes() == 1)) {
    stop("We are unable to create initial values from a model with no samples. Please check the results of the model used for inits before continuing.")
  } else if (!is.null(model_variables) &&
             !any(names(model_variables$parameters) %in% init$metadata()$stan_variables)) {
    # Only meaningful when model information was supplied; without the NULL
    # guard this branch would fire for every call with `model_variables = NULL`.
    stop("None of the names of the parameters for the model used for initial values match the names of parameters from the model currently running.")
  }
  draws_df <- init$draws(format = "df")
  if (is.null(model_variables)) {
    # Fallback: use every column except the first and the last three.
    # Presumably these are `lp__` and the `.chain`/`.iteration`/`.draw`
    # bookkeeping columns -- TODO confirm.
    # NOTE(review): this is an unnamed character vector, while the check above
    # reads `names(model_variables$parameters)`; confirm downstream methods
    # accept this shape.
    model_variables <- list(parameters = colnames(draws_df)[2:(length(colnames(draws_df)) - 3)])
  }
  # Unlike the other fit methods this uses "simple" rather than
  # "simple_no_replace"; presumably because an optimization fit holds a single
  # point estimate that must be reused for every init -- TODO confirm.
  init_draws_df <- posterior::resample_draws(draws_df, ndraws = num_procs,
                                             method = "simple")
  process_init(init_draws_df, num_procs = num_procs,
               model_variables = model_variables, warn_partial = warn_partial)
}

#' Validate initial values
Expand Down
2 changes: 1 addition & 1 deletion R/cmdstanr-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,6 @@
#' @inherit cmdstan_model examples
#' @import R6
#'
NULL
"_PACKAGE"

if (getRversion() >= "2.15.1") utils::globalVariables(c("self", "private", "super"))
6 changes: 3 additions & 3 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -518,11 +518,11 @@ unconstrain_variables <- function(variables) {
" not provided!", call. = FALSE)
}

# Remove zero-length parameters from model_variables, otherwise process_init_list
# Remove zero-length parameters from model_variables, otherwise process_init
# warns about missing inputs
model_variables$parameters <- model_variables$parameters[nonzero_length_params]

stan_pars <- process_init_list(list(variables), num_procs = 1, model_variables)
stan_pars <- process_init(list(variables), num_procs = 1, model_variables)
private$model_methods_env_$unconstrain_variables(private$model_methods_env_$model_ptr_, stan_pars)
}
CmdStanFit$set("public", name = "unconstrain_variables", value = unconstrain_variables)
Expand Down Expand Up @@ -594,7 +594,7 @@ unconstrain_draws <- function(files = NULL, draws = NULL,
# but not in metadata()$variables
nonzero_length_params <- names(model_variables$parameters) %in% model_par_names

# Remove zero-length parameters from model_variables, otherwise process_init_list
# Remove zero-length parameters from model_variables, otherwise process_init
# warns about missing inputs
pars <- names(model_variables$parameters[nonzero_length_params])

Expand Down
2 changes: 1 addition & 1 deletion man/model-method-check_syntax.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/model-method-compile.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/model-method-diagnose.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/model-method-expose_functions.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/model-method-format.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/model-method-generate-quantities.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/model-method-laplace.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit ae46447

Please sign in to comment.