Skip to content

Commit

Permalink
update docs and failing test
Browse files Browse the repository at this point in the history
  • Loading branch information
SteveBronder committed Apr 18, 2024
1 parent 238e9bb commit 9821640
Show file tree
Hide file tree
Showing 10 changed files with 109 additions and 75 deletions.
37 changes: 29 additions & 8 deletions R/args.R
Original file line number Diff line number Diff line change
Expand Up @@ -1268,23 +1268,29 @@ process_init_approx <- function(init, num_procs, model_variables = NULL,
# Calculate unique draws based on 'lw' using base R functions
unique_draws = length(unique(draws_df$lw))
if (num_procs > unique_draws) {
if (inherits(init, "CmdStanPathfinder")) {
stop(paste0("Not enough distinct draws (", num_procs, ") in pathfinder fit to create inits. Try running Pathfinder with psis_resample=FALSE"))
if (inherits(init, " CmdStanPathfinder ")) {
algo_name = " Pathfinder "
extra_msg = " Try running Pathfinder with psis_resample=FALSE."
} else if (inherits(init, "CmdStanVB")) {
algo_name = " CmdStanVB "
extra_msg = ""
} else if (inherits(init, " CmdStanLaplace ")) {
algo_name = " CmdStanLaplace "
extra_msg = ""
} else {
stop(paste0("Not enough distinct draws (", num_procs, ") to create inits."))
algo_name = ""
extra_msg = ""
}
stop(paste0("Not enough distinct draws (", num_procs, ") in", algo_name ,
"fit to create inits.", extra_msg))
}
if (unique_draws < (0.95 * nrow(draws_df))) {
temp_df = stats::aggregate(.draw ~ lw, data = draws_df, FUN = min)
draws_df = posterior::as_draws_df(merge(temp_df, draws_df, by = 'lw'))
draws_df$weight = exp(draws_df$lw - max(draws_df$lw))
} else {
if (inherits(init, "CmdStanPathfinder") && (init$metadata()$psis_resample || !init$metadata()$calculate_lp)) {
draws_df$weight = rep(1.0, nrow(draws_df))
} else {
draws_df$weight = posterior::pareto_smooth(
exp(draws_df$lw - max(draws_df$lw)), tail = "right", return_k=FALSE)
}
}
init_draws_df = posterior::resample_draws(draws_df, ndraws = num_procs,
weights = draws_df$weight, method = "simple_no_replace")
Expand All @@ -1308,7 +1314,22 @@ process_init_approx <- function(init, num_procs, model_variables = NULL,
process_init.CmdStanPathfinder <- function(init, num_procs, model_variables = NULL,
warn_partial = getOption("cmdstanr_warn_inits", TRUE),
...) {
process_init_approx(init, num_procs, model_variables, warn_partial)
if (!init$metadata()$calculate_lp) {
validate_fit_init(init, model_variables)
# Convert from data.table to data.frame
draws_df = init$draws(format = "df")
if (is.null(model_variables)) {
model_variables = list(parameters = colnames(draws_df)[3:(length(colnames(draws_df)) - 3)])
}
draws_df$weight = rep(1.0, nrow(draws_df))
init_draws_df = posterior::resample_draws(draws_df, ndraws = num_procs,
weights = draws_df$weight, method = "simple_no_replace")
init_draws_lst = process_init(init_draws_df,
num_procs = num_procs, model_variables = model_variables, warn_partial)
return(init_draws_lst)
} else {
process_init_approx(init, num_procs, model_variables, warn_partial)
}
}

#' Write initial values to files if provided as a `CmdStanVB` class
Expand Down
18 changes: 10 additions & 8 deletions man-roxygen/model-common-args.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,21 @@
#' has argument `chain_id` it will be supplied with the chain id (from 1 to
#' number of chains) when called to generate the initial values. See
#' **Examples**.
#' * A [`CmdStanMCMC`], [`CmdStanMLE`], [`CmdStanVB`], or [`CmdStanPathfinder`]
#' fit object. If the fit object's parameters are only a subset of the model
#' * A [`CmdStanMCMC`], [`CmdStanMLE`], [`CmdStanVB`], [`CmdStanPathfinder`],
#' or [`CmdStanLaplace`] fit object.
#' If the fit object's parameters are only a subset of the model
#' parameters then the other parameters will be drawn by Stan's default
#' initialization. The fit object must have at least some parameters that are the
#' same name and dimensions as the current Stan model. For the `sampling` and
#' `pathfinder` method, if the fit object has less samples than the requested
#' same name and dimensions as the current Stan model. For the `sample` and
#' `pathfinder` method, if the fit object has fewer draws than the requested
#' number of chains/paths then the inits will be drawn using sampling with
#' replacement. Otherwise sampling without replacement will be used.
#' When a [`CmdStanPathfinder`] fit object is used as the init, if
#' `psis_resample` was set to `FALSE` and `calculate_lp` was
#' set to `TRUE` (default), then PSIS resampling will be used as weights.
#' if `calculate_lp` is `FALSE` then sampling without replacement will be used
#' to select the draws.
#'. `psis_resample` was set to `FALSE` and `calculate_lp` was
#' set to `TRUE` (default), then resampling without replacement with Pareto
#' smoothed weights will be used. If `psis_resample` was set to `TRUE` or
#' `calculate_lp` was set to `FALSE` then sampling without replacement with
#' uniform weights will be used to select the draws.
#' PSIS resampling is used to select the draws for [`CmdStanVB`] fit objects.
#'
#' * A type inheriting from `posterior::draws`. If the draws object has less
Expand Down
18 changes: 10 additions & 8 deletions man/model-method-diagnose.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 10 additions & 8 deletions man/model-method-laplace.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 10 additions & 8 deletions man/model-method-optimize.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 10 additions & 8 deletions man/model-method-pathfinder.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 10 additions & 8 deletions man/model-method-sample.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 10 additions & 8 deletions man/model-method-sample_mpi.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 9821640

Please sign in to comment.