Skip to content

Commit

Permalink
update docs and main function signature
Browse files Browse the repository at this point in the history
  • Loading branch information
SteveBronder committed Jan 18, 2024
1 parent c69ba62 commit dc97a7e
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 3 deletions.
8 changes: 7 additions & 1 deletion R/args.R
Original file line number Diff line number Diff line change
Expand Up @@ -914,7 +914,7 @@ validate_variational_args <- function(self) {
if (!is.null(self$eval_elbo)) {
self$eval_elbo <- as.integer(self$eval_elbo)
}
checkmate::assert_integerish(self$output_samples, null.ok = TRUE,
checkmate::assert_inset_cmdstan_pathtegerish(self$output_samples, null.ok = TRUE,
lower = 1, len = 1, .var.name = "draws")
if (!is.null(self$output_samples)) {
self$output_samples <- as.integer(self$output_samples)
Expand Down Expand Up @@ -972,8 +972,14 @@ validate_pathfinder_args <- function(self) {
if (!is.null(self$save_single_paths)) {
self$save_single_paths <- 0
}
if (!is.null(self$psis_resample) && is.logical(self$psis_resample)) {
self$psis_resample = as.integer(self$psis_resample)
}
checkmate::assert_integerish(self$psis_resample, null.ok = TRUE,
lower = 0, upper = 1, len = 1)
if (!is.null(self$calculate_lp) && is.logical(self$calculate_lp)) {
self$calculate_lp = as.integer(self$calculate_lp)
}
checkmate::assert_integerish(self$calculate_lp, null.ok = TRUE,
lower = 0, upper = 1, len = 1)

Expand Down
17 changes: 16 additions & 1 deletion R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -1884,6 +1884,17 @@ CmdStanModel$set("public", name = "variational", value = variational)
#' calculating the ELBO of the approximation at each iteration of LBFGS.
#' @param save_single_paths (logical) Whether to save the results of single
#' pathfinder runs in multi-pathfinder.
#' @param psis_sample (logical) Whether to perform pareto smoothed importance sampling.
#' If `TRUE`, the number of draws returned will be equal to `draws`.
#' If `FALSE`, the number of draws returned will be equal to `single_path_draws * num_paths`.
#' @param calculate_lp (logical) Whether to calculate the log probability of the draws.
#' If `TRUE`, the log probability will be calculated and given in the output.
#' If `FALSE`, the log probability will only be returned for draws used to determine the
#' ELBO in the pathfinder steps. All other draws will have a log probability of `NA`.
#' A value of `FALSE` will also turn off pareto smoothed importance sampling as the
#' lp calculation is needed for PSIS.
#' @param save_single_paths (logical) Whether to save the results of single
#' pathfinder runs in multi-pathfinder.
#' @return A [`CmdStanPathfinder`] object.
#'
#' @template seealso-docs
Expand Down Expand Up @@ -1912,6 +1923,8 @@ pathfinder <- function(data = NULL,
max_lbfgs_iters = NULL,
num_elbo_draws = NULL,
save_single_paths = NULL,
psis_resample = NULL,
calculate_lp = NULL,
show_messages = TRUE,
show_exceptions = TRUE) {
procs <- CmdStanProcs$new(
Expand All @@ -1937,7 +1950,9 @@ pathfinder <- function(data = NULL,
num_paths = num_paths,
max_lbfgs_iters = max_lbfgs_iters,
num_elbo_draws = num_elbo_draws,
save_single_paths = save_single_paths
save_single_paths = save_single_paths,
psis_resample = psis_resample,
calculate_lp = calculate_lp
)
args <- CmdStanArgs$new(
method_args = pathfinder_args,
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-model-pathfinder.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
context("model-pathfinder")

set_cmdstan_path()
set_cmdstan_path("/mnt/home/sbronder/opensource/stan/origin/cmdstan")
stan_program <- testing_stan_file("bernoulli")
mod <- testing_model("bernoulli")
stan_program_fp <- testing_stan_file("bernoulli_fp")
Expand Down

0 comments on commit dc97a7e

Please sign in to comment.