Skip to content

Commit

Permalink
Merge pull request #848 from stan-dev/feature/pathfinder
Browse files Browse the repository at this point in the history
Pathfinder
  • Loading branch information
jgabry authored Nov 8, 2023
2 parents fca3844 + 094d797 commit 357a07e
Show file tree
Hide file tree
Showing 35 changed files with 1,015 additions and 68 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ Authors@R:
person(given = c("William", "Michael"), family = "Landau", role = "ctb",
email = "[email protected]", comment = c(ORCID = "0000-0003-1878-3253")),
person(given = "Jacob", family = "Socolar", role = "ctb"),
person(given = "Martin", family = "Modrák", role = "ctb"))
person(given = "Martin", family = "Modrák", role = "ctb"),
person(given = "Steve", family = "Bronder", role = "ctb"))
Description: A lightweight interface to 'Stan' <https://mc-stan.org>.
The 'CmdStanR' interface is an alternative to 'RStan' that calls the command
line interface for compilation and running algorithms instead of interfacing
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ S3method(as_draws,CmdStanGQ)
S3method(as_draws,CmdStanLaplace)
S3method(as_draws,CmdStanMCMC)
S3method(as_draws,CmdStanMLE)
S3method(as_draws,CmdStanPathfinder)
S3method(as_draws,CmdStanVB)
export(as_cmdstan_fit)
export(as_draws)
Expand Down
130 changes: 129 additions & 1 deletion R/args.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#' * `OptimizeArgs`: stores arguments specific to `method=optimize`.
#' * `LaplaceArgs`: stores arguments specific to `method=laplace`.
#' * `VariationalArgs`: stores arguments specific to `method=variational`.
#' * `PathfinderArgs`: stores arguments specific to `method=pathfinder`.
#' * `GenerateQuantitiesArgs`: stores arguments specific to `method=generate_quantities`.
#' * `DiagnoseArgs`: stores arguments specific to `method=diagnose`.
#'
Expand All @@ -41,7 +42,8 @@ CmdStanArgs <- R6::R6Class(
output_basename = NULL,
sig_figs = NULL,
opencl_ids = NULL,
model_variables = NULL) {
model_variables = NULL,
num_threads = NULL) {

self$model_name <- model_name
self$stan_code <- stan_code
Expand Down Expand Up @@ -82,6 +84,7 @@ CmdStanArgs <- R6::R6Class(
}
self$init <- init
self$opencl_ids <- opencl_ids
self$num_threads = NULL
self$method_args$validate(num_procs = length(self$proc_ids))
self$validate()
},
Expand Down Expand Up @@ -179,6 +182,9 @@ CmdStanArgs <- R6::R6Class(
if (!is.null(self$opencl_ids)) {
args$opencl <- c("opencl", paste0("platform=", self$opencl_ids[1]), paste0("device=", self$opencl_ids[2]))
}
if (!is.null(self$num_threads)) {
num_threads <- c(args$output, paste0("num_threads=", self$num_threads))
}
args <- do.call(c, append(args, list(use.names = FALSE)))
self$method_args$compose(idx, args)
},
Expand Down Expand Up @@ -541,6 +547,75 @@ VariationalArgs <- R6::R6Class(
)
)

# PathfinderArgs ---------------------------------------------------------

PathfinderArgs <- R6::R6Class(
  "PathfinderArgs",
  # Fields are created on the fly in `initialize()`, so the object must stay
  # unlocked.
  lock_objects = FALSE,
  public = list(
    method = "pathfinder",

    # Store the user-supplied pathfinder settings on the object. No validation
    # happens here; see `validate()`.
    initialize = function(init_alpha = NULL,
                          tol_obj = NULL,
                          tol_rel_obj = NULL,
                          tol_grad = NULL,
                          tol_rel_grad = NULL,
                          tol_param = NULL,
                          history_size = NULL,
                          single_path_draws = NULL,
                          draws = NULL,
                          num_paths = NULL,
                          max_lbfgs_iters = NULL,
                          num_elbo_draws = NULL,
                          save_single_paths = NULL) {
      self$init_alpha <- init_alpha
      self$tol_obj <- tol_obj
      self$tol_rel_obj <- tol_rel_obj
      self$tol_grad <- tol_grad
      self$tol_rel_grad <- tol_rel_grad
      self$tol_param <- tol_param
      self$history_size <- history_size
      # The two draw-count arguments are stored under different field names:
      # `draws` becomes `num_psis_draws` and `single_path_draws` becomes
      # `num_draws`. These field names are the ones used in `compose()`.
      self$num_psis_draws <- draws
      self$num_draws <- single_path_draws
      self$num_paths <- num_paths
      self$max_lbfgs_iters <- max_lbfgs_iters
      self$num_elbo_draws <- num_elbo_draws
      self$save_single_paths <- save_single_paths
      invisible(self)
    },

    # Run the pathfinder-specific checks. `num_procs` is accepted but not used.
    validate = function(num_procs) {
      validate_pathfinder_args(self)
    },

    # Compose the pathfinder-specific part of the CmdStan command.
    #
    # `idx` is accepted but deliberately not forwarded (every argument is
    # composed with `idx = NULL`), because multiple pathfinders are handled
    # in cmdstan. The composed arguments are appended to `args`.
    compose = function(idx = NULL, args = NULL) {
      arg_names <- c(
        "init_alpha",
        "tol_obj",
        "tol_rel_obj",
        "tol_grad",
        "tol_rel_grad",
        "tol_param",
        "history_size",
        "num_psis_draws",
        "num_draws",
        "num_paths",
        "max_lbfgs_iters",
        "num_elbo_draws",
        "save_single_paths"
      )
      composed <- lapply(arg_names, function(arg_name) {
        compose_arg(self, arg_name, idx = NULL)
      })
      # Flatten into a single vector with the method selector first; any NULL
      # entries in the list vanish when combined with `c()`.
      pathfinder_args <- do.call(c, c(list("method=pathfinder"), composed))
      c(args, pathfinder_args)
    }
  )
)

# DiagnoseArgs -------------------------------------------------------------

DiagnoseArgs <- R6::R6Class(
Expand Down Expand Up @@ -854,6 +929,59 @@ validate_variational_args <- function(self) {
invisible(TRUE)
}

#' Validate arguments for pathfinder inference
#'
#' Checks every pathfinder-specific field of `self` and, where a value was
#' supplied, stores it back on the object coerced to integer. Fields left as
#' `NULL` are accepted and left untouched.
#'
#' @noRd
#' @param self A `PathfinderArgs` object. It is modified in place.
#' @return `TRUE` invisibly unless an error is thrown.
validate_pathfinder_args <- function(self) {

  # Positive whole-number settings: validate, then store as integer.
  checkmate::assert_integerish(self$max_lbfgs_iters, lower = 1, null.ok = TRUE, len = 1)
  if (!is.null(self$max_lbfgs_iters)) {
    # Write back to the field that was validated (previously the coerced value
    # went to an unrelated `iter` field and this one was never converted).
    self$max_lbfgs_iters <- as.integer(self$max_lbfgs_iters)
  }
  checkmate::assert_integerish(self$num_paths, lower = 1, null.ok = TRUE,
                               len = 1)
  if (!is.null(self$num_paths)) {
    self$num_paths <- as.integer(self$num_paths)
  }
  # `.var.name` reports the user-facing argument name rather than the
  # internal field name in error messages.
  checkmate::assert_integerish(self$num_draws, lower = 1, null.ok = TRUE,
                               len = 1, .var.name = "single_path_draws")
  if (!is.null(self$num_draws)) {
    self$num_draws <- as.integer(self$num_draws)
  }
  checkmate::assert_integerish(self$num_psis_draws, lower = 1, null.ok = TRUE,
                               len = 1, .var.name = "draws")
  if (!is.null(self$num_psis_draws)) {
    self$num_psis_draws <- as.integer(self$num_psis_draws)
  }
  checkmate::assert_integerish(self$num_elbo_draws, lower = 1, null.ok = TRUE, len = 1)
  if (!is.null(self$num_elbo_draws)) {
    self$num_elbo_draws <- as.integer(self$num_elbo_draws)
  }

  # `save_single_paths` is a 0/1 flag that may also be given as a logical.
  if (!is.null(self$save_single_paths) && is.logical(self$save_single_paths)) {
    self$save_single_paths <- as.integer(self$save_single_paths)
  }
  checkmate::assert_integerish(self$save_single_paths, null.ok = TRUE,
                               lower = 0, upper = 1, len = 1,
                               any.missing = FALSE)
  if (!is.null(self$save_single_paths)) {
    # Keep the value the user asked for (previously it was always reset to 0,
    # which made it impossible to enable the option).
    self$save_single_paths <- as.integer(self$save_single_paths)
  }

  # Optimizer tolerances and initial step size: single non-negative numbers.
  bfgs_args <- c("init_alpha", "tol_obj", "tol_rel_obj", "tol_grad", "tol_rel_grad", "tol_param")
  for (arg in bfgs_args) {
    checkmate::assert_number(self[[arg]], .var.name = arg, lower = 0, null.ok = TRUE)
  }

  if (!is.null(self$history_size)) {
    checkmate::assert_integerish(self$history_size, lower = 1, len = 1, null.ok = FALSE)
    self$history_size <- as.integer(self$history_size)
  }

  invisible(TRUE)
}


# Validation helpers ------------------------------------------------------

Expand Down
39 changes: 38 additions & 1 deletion R/csv.R
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,10 @@ read_cmdstan_csv <- function(files,
if (length(variables) > 0) {
draws_list_id <- length(draws) + 1
warmup_draws_list_id <- length(warmup_draws) + 1
if (metadata$method == "pathfinder") {
metadata$variables = union(metadata$sampler_diagnostics, metadata$variables)
variables = union(metadata$sampler_diagnostics, variables)
}
suppressWarnings(
draws[[draws_list_id]] <- data.table::fread(
cmd = fread_cmd,
Expand Down Expand Up @@ -445,6 +449,21 @@ read_cmdstan_csv <- function(files,
metadata = metadata,
generated_quantities = draws
)
} else if (metadata$method == "pathfinder") {
if (is.null(format)) {
format <- "draws_matrix"
}
as_draws_format <- as_draws_format_fun(format)
if (length(draws) == 0) {
pathfinder_draws <- NULL
} else {
pathfinder_draws <- do.call(as_draws_format, list(draws[[1]][, colnames(draws[[1]]), drop = FALSE]))
posterior::variables(pathfinder_draws) <- repaired_variables
}
list(
metadata = metadata,
draws = pathfinder_draws
)
}
}

Expand Down Expand Up @@ -477,6 +496,7 @@ as_cmdstan_fit <- function(files, check_diagnostics = TRUE, format = getOption("
"sample" = CmdStanMCMC_CSV$new(csv_contents, files, check_diagnostics),
"optimize" = CmdStanMLE_CSV$new(csv_contents, files),
"variational" = CmdStanVB_CSV$new(csv_contents, files),
"pathfinder" = CmdStanPathfinder_CSV$new(csv_contents, files),
"laplace" = CmdStanLaplace_CSV$new(csv_contents, files)
)
}
Expand Down Expand Up @@ -576,6 +596,23 @@ CmdStanVB_CSV <- R6::R6Class(
private = list(output_files_ = NULL)
)

# Pathfinder fit object built directly from already-parsed CSV contents
# rather than from a CmdStan run.
CmdStanPathfinder_CSV <- R6::R6Class(
  classname = "CmdStanPathfinder_CSV",
  inherit = CmdStanPathfinder,
  private = list(output_files_ = NULL),
  public = list(
    # Return the file paths given at construction; extra arguments are
    # accepted and ignored.
    output_files = function(...) {
      private$output_files_
    },
    # `csv_contents` must provide `draws` and `metadata` elements; `files`
    # holds the paths the contents were read from.
    initialize = function(csv_contents, files) {
      private$draws_ <- csv_contents$draws
      private$output_files_ <- files
      private$metadata_ <- csv_contents$metadata
    }
  )
)


# these methods are unavailable because there's no CmdStanRun object
unavailable_methods_CmdStanFit_CSV <- c(
"cmdstan_diagnose", "cmdstan_summary",
Expand Down Expand Up @@ -642,7 +679,7 @@ read_csv_metadata <- function(csv_file) {
"\""
)
} else {
fread_cmd <- paste0("grep '^[#a-zA-Z]' --color=never '", csv_file, "'")
fread_cmd <- paste0("grep '^[#a-zA-Z]' --color=never '", path.expand(csv_file), "'")
}
suppressWarnings(
metadata <- data.table::fread(
Expand Down
2 changes: 1 addition & 1 deletion R/example.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
#'
cmdstanr_example <-
function(example = c("logistic", "schools", "schools_ncp"),
method = c("sample", "optimize", "laplace", "variational", "diagnose"),
method = c("sample", "optimize", "laplace", "variational", "pathfinder", "diagnose"),
...,
quiet = TRUE,
force_recompile = getOption("cmdstanr_force_recompile", default = FALSE)) {
Expand Down
80 changes: 80 additions & 0 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -2024,6 +2024,80 @@ CmdStanVB <- R6::R6Class(
)
CmdStanVB$set("public", name = "lp_approx", value = lp_approx)

# CmdStanPathfinder ---------------------------------------------------------------
#' CmdStanPathfinder objects
#'
#' @name CmdStanPathfinder
#' @family fitted model objects
#' @template seealso-docs
#'
#' @description A `CmdStanPathfinder` object is the fitted model object returned by the
#' [`$pathfinder()`][model-method-pathfinder] method of a
#' [`CmdStanModel`] object.
#'
#' @section Methods: `CmdStanPathfinder` objects have the following associated methods,
#' all of which have their own (linked) documentation pages.
#'
#' ## Extract contents of fitted model object
#'
#' |**Method**|**Description**|
#' |:----------|:---------------|
#' [`$draws()`][fit-method-draws] | Return approximate posterior draws as a [`draws_matrix`][posterior::draws_matrix]. |
#' [`$lp()`][fit-method-lp] | Return the total log probability density (`target`) computed in the model block of the Stan program. |
#' [`$lp_approx()`][fit-method-lp] | Return the log density of the approximation to the posterior. |
#' [`$init()`][fit-method-init] | Return user-specified initial values. |
#' [`$metadata()`][fit-method-metadata] | Return a list of metadata gathered from the CmdStan CSV files. |
#' [`$code()`][fit-method-code] | Return Stan code as a character vector. |
#'
#' ## Summarize inferences
#'
#' |**Method**|**Description**|
#' |:----------|:---------------|
#' [`$summary()`][fit-method-summary] | Run [`posterior::summarise_draws()`][posterior::draws_summary]. |
#' [`$cmdstan_summary()`][fit-method-cmdstan_summary] | Run and print CmdStan's `bin/stansummary`. |
#'
#' ## Save fitted model object and temporary files
#'
#' |**Method**|**Description**|
#' |:----------|:---------------|
#' [`$save_object()`][fit-method-save_object] | Save fitted model object to a file. |
#' [`$save_output_files()`][fit-method-save_output_files] | Save output CSV files to a specified location. |
#' [`$save_data_file()`][fit-method-save_data_file] | Save JSON data file to a specified location. |
#' [`$save_latent_dynamics_files()`][fit-method-save_latent_dynamics_files] | Save diagnostic CSV files to a specified location. |
#'
#' ## Report run times, console output, return codes
#'
#' |**Method**|**Description**|
#' |:----------|:---------------|
#' [`$time()`][fit-method-time] | Report the total run time. |
#' [`$output()`][fit-method-output] | Pretty print the output that was printed to the console. |
#' [`$return_codes()`][fit-method-return_codes] | Return the return codes from the CmdStan runs. |
#'
CmdStanPathfinder <- R6::R6Class(
  classname = "CmdStanPathfinder",
  inherit = CmdStanFit,
  public = list(),
  private = list(
    # Read the output CSV files and cache their draws and metadata in the
    # `draws_` and `metadata_` slots inherited from CmdStanFit.
    # Errors if there are no output files once failed runs are excluded.
    read_csv_ = function(format = getOption("cmdstanr_draws_format", "draws_matrix")) {
      successful_files <- self$output_files(include_failed = FALSE)
      if (length(successful_files) == 0) {
        stop("Pathfinder failed. Unable to retrieve the draws.", call. = FALSE)
      }
      parsed <- read_cmdstan_csv(self$output_files(), format = format)
      private$metadata_ <- parsed$metadata
      private$draws_ <- parsed$draws
      invisible(self)
    }
  )
)

#' @rdname fit-method-lp
lp_approx <- function() {
  # `self` is not an argument: it is available once this function is attached
  # to the class as a method below. Pull the `lp_approx__` column out of the
  # draws and return it as a plain numeric vector.
  approx_column <- self$draws()[, "lp_approx__"]
  as.numeric(approx_column)
}
# Register the function above as the public `$lp_approx()` method.
CmdStanPathfinder$set("public", name = "lp_approx", value = lp_approx)



# CmdStanGQ ---------------------------------------------------------------
#' CmdStanGQ objects
Expand Down Expand Up @@ -2290,3 +2364,9 @@ as_draws.CmdStanVB <- function(x, ...) {
as_draws.CmdStanGQ <- function(x, ...) {
x$draws(...)
}

#' @rdname as_draws.CmdStanMCMC
#' @export
as_draws.CmdStanPathfinder <- function(x, ...) {
  # Delegate to the object's own `$draws()` method, forwarding every extra
  # argument unchanged.
  extract_draws <- x$draws
  extract_draws(...)
}
Loading

0 comments on commit 357a07e

Please sign in to comment.