Skip to content

Commit

Permalink
update pathfinder args for psis_resample and lp_calculate (#903)
Browse files Browse the repository at this point in the history
* update pathfinder args for psis_resample and lp_calculate

---------

Co-authored-by: Andrew Johnson <[email protected]>
  • Loading branch information
SteveBronder and andrjohns authored Jan 26, 2024
1 parent 3c7a1a9 commit d3b455f
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 4 deletions.
20 changes: 18 additions & 2 deletions R/args.R
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,9 @@ PathfinderArgs <- R6::R6Class(
num_paths = NULL,
max_lbfgs_iters = NULL,
num_elbo_draws = NULL,
save_single_paths = NULL) {
save_single_paths = NULL,
psis_resample = NULL,
calculate_lp = NULL) {
self$init_alpha <- init_alpha
self$tol_obj <- tol_obj
self$tol_rel_obj <- tol_rel_obj
Expand All @@ -580,6 +582,8 @@ PathfinderArgs <- R6::R6Class(
self$max_lbfgs_iters <- max_lbfgs_iters
self$num_elbo_draws <- num_elbo_draws
self$save_single_paths <- save_single_paths
self$psis_resample <- psis_resample
self$calculate_lp <- calculate_lp
invisible(self)
},

Expand Down Expand Up @@ -608,7 +612,9 @@ PathfinderArgs <- R6::R6Class(
.make_arg("num_paths"),
.make_arg("max_lbfgs_iters"),
.make_arg("num_elbo_draws"),
.make_arg("save_single_paths")
.make_arg("save_single_paths"),
.make_arg("psis_resample"),
.make_arg("calculate_lp")
)
new_args <- do.call(c, new_args)
c(args, new_args)
Expand Down Expand Up @@ -966,6 +972,16 @@ validate_pathfinder_args <- function(self) {
if (!is.null(self$save_single_paths)) {
self$save_single_paths <- 0
}
if (!is.null(self$psis_resample) && is.logical(self$psis_resample)) {
self$psis_resample = as.integer(self$psis_resample)
}
checkmate::assert_integerish(self$psis_resample, null.ok = TRUE,
lower = 0, upper = 1, len = 1)
if (!is.null(self$calculate_lp) && is.logical(self$calculate_lp)) {
self$calculate_lp = as.integer(self$calculate_lp)
}
checkmate::assert_integerish(self$calculate_lp, null.ok = TRUE,
lower = 0, upper = 1, len = 1)


# check args only available for lbfgs and bfgs
Expand Down
17 changes: 16 additions & 1 deletion R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -1887,6 +1887,17 @@ CmdStanModel$set("public", name = "variational", value = variational)
#' calculating the ELBO of the approximation at each iteration of LBFGS.
#' @param save_single_paths (logical) Whether to save the results of single
#' pathfinder runs in multi-pathfinder.
#' @param psis_resample (logical) Whether to perform pareto smoothed importance sampling.
#' If `TRUE`, the number of draws returned will be equal to `draws`.
#' If `FALSE`, the number of draws returned will be equal to `single_path_draws * num_paths`.
#' @param calculate_lp (logical) Whether to calculate the log probability of the draws.
#' If `TRUE`, the log probability will be calculated and given in the output.
#' If `FALSE`, the log probability will only be returned for draws used to determine the
#' ELBO in the pathfinder steps. All other draws will have a log probability of `NA`.
#' A value of `FALSE` will also turn off pareto smoothed importance sampling as the
#' lp calculation is needed for PSIS.
#' @param save_single_paths (logical) Whether to save the results of single
#' pathfinder runs in multi-pathfinder.
#' @return A [`CmdStanPathfinder`] object.
#'
#' @template seealso-docs
Expand Down Expand Up @@ -1915,6 +1926,8 @@ pathfinder <- function(data = NULL,
max_lbfgs_iters = NULL,
num_elbo_draws = NULL,
save_single_paths = NULL,
psis_resample = NULL,
calculate_lp = NULL,
show_messages = TRUE,
show_exceptions = TRUE) {
procs <- CmdStanProcs$new(
Expand All @@ -1940,7 +1953,9 @@ pathfinder <- function(data = NULL,
num_paths = num_paths,
max_lbfgs_iters = max_lbfgs_iters,
num_elbo_draws = num_elbo_draws,
save_single_paths = save_single_paths
save_single_paths = save_single_paths,
psis_resample = psis_resample,
calculate_lp = calculate_lp
)
args <- CmdStanArgs$new(
method_args = pathfinder_args,
Expand Down
13 changes: 13 additions & 0 deletions man/model-method-pathfinder.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion tests/testthat/test-model-pathfinder.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ ok_arg_values <- list(
draws = 100,
num_paths = 4,
max_lbfgs_iters = 100,
save_single_paths = FALSE)
save_single_paths = FALSE,
calculate_lp = TRUE,
psis_resample=TRUE)

# using any one of these should cause sample() to error
bad_arg_values <- list(
Expand Down

0 comments on commit d3b455f

Please sign in to comment.