Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pathfinder #848

Merged
merged 25 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
3c82767
update to have a route for pathfinder
SteveBronder Dec 7, 2021
56d65d3
adds multi pathfinder paths
SteveBronder Dec 8, 2021
518089a
fix pathfinder names
SteveBronder Dec 8, 2021
14eab9a
update to allow for num_vals pathfinder command
SteveBronder Jan 26, 2022
6be1c02
Merge pull request #727 from stan-dev/master
jgabry Jul 31, 2023
697b07f
Merge remote-tracking branch 'origin/master' into feature/pathfinder
SteveBronder Aug 30, 2023
b7ca3cf
fix lp_approx__ thing
SteveBronder Aug 30, 2023
10ea08a
update to remove single pathfinder save and write the tests
SteveBronder Sep 8, 2023
dc5d402
update tests
SteveBronder Sep 11, 2023
25d484b
Merge branch 'master' into feature/pathfinder
andrjohns Sep 13, 2023
a05470e
Merge remote-tracking branch 'origin/master' into feature/pathfinder
SteveBronder Sep 25, 2023
0913b32
update with master
SteveBronder Sep 25, 2023
7070020
Fix argument names to line up with cmdstanpy args
SteveBronder Sep 25, 2023
c15bf48
update tests
SteveBronder Sep 25, 2023
1edd938
update with master
SteveBronder Sep 25, 2023
49eaa40
update roxygen
SteveBronder Sep 26, 2023
3a27f5f
Merge remote-tracking branch 'origin' into feature/pathfinder
SteveBronder Sep 26, 2023
f606056
set num_draws to single_path_draws and num_psis_draws to draws
SteveBronder Sep 29, 2023
2c40d76
update docs
SteveBronder Sep 29, 2023
0526f11
update error message var names
SteveBronder Oct 5, 2023
ec7f6c0
Update R/model.R
jgabry Nov 1, 2023
3cab26a
fix doc issues in review comments
jgabry Nov 1, 2023
204f217
Fix bug which returned -1 draws
SteveBronder Nov 3, 2023
387dfd2
Add Steve Bronder to description file
SteveBronder Nov 3, 2023
094d797
fix test error
SteveBronder Nov 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ S3method(as_draws,CmdStanGQ)
S3method(as_draws,CmdStanLaplace)
S3method(as_draws,CmdStanMCMC)
S3method(as_draws,CmdStanMLE)
S3method(as_draws,CmdStanPathfinder)
S3method(as_draws,CmdStanVB)
export(as_cmdstan_fit)
export(as_draws)
Expand Down
120 changes: 119 additions & 1 deletion R/args.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#' * `OptimizeArgs`: stores arguments specific to `method=optimize`.
#' * `LaplaceArgs`: stores arguments specific to `method=laplace`.
#' * `VariationalArgs`: stores arguments specific to `method=variational`
#' * `PathfinderArgs`: stores arguments specific to `method=pathfinder`
#' * `GenerateQuantitiesArgs`: stores arguments specific to `method=generate_quantities`
#' * `DiagnoseArgs`: stores arguments specific to `method=diagnose`
#'
Expand All @@ -41,7 +42,8 @@ CmdStanArgs <- R6::R6Class(
output_basename = NULL,
sig_figs = NULL,
opencl_ids = NULL,
model_variables = NULL) {
model_variables = NULL,
num_threads = NULL) {

self$model_name <- model_name
self$stan_code <- stan_code
Expand Down Expand Up @@ -82,6 +84,7 @@ CmdStanArgs <- R6::R6Class(
}
self$init <- init
self$opencl_ids <- opencl_ids
self$num_threads = NULL
self$method_args$validate(num_procs = length(self$proc_ids))
self$validate()
},
Expand Down Expand Up @@ -179,6 +182,9 @@ CmdStanArgs <- R6::R6Class(
if (!is.null(self$opencl_ids)) {
args$opencl <- c("opencl", paste0("platform=", self$opencl_ids[1]), paste0("device=", self$opencl_ids[2]))
}
if (!is.null(self$num_threads)) {
num_threads <- c(args$output, paste0("num_threads=", self$num_threads))
}
args <- do.call(c, append(args, list(use.names = FALSE)))
self$method_args$compose(idx, args)
},
Expand Down Expand Up @@ -541,6 +547,72 @@ VariationalArgs <- R6::R6Class(
)
)

# PathfinderArgs ---------------------------------------------------------

PathfinderArgs <- R6::R6Class(
"PathfinderArgs",
lock_objects = FALSE,
public = list(
method = "pathfinder",
initialize = function(init_alpha = NULL,
tol_obj = NULL,
tol_rel_obj = NULL,
tol_grad = NULL,
tol_rel_grad = NULL,
tol_param = NULL,
history_size = NULL,
num_draws = NULL,
num_paths = NULL,
max_lbfgs_iters = NULL,
num_elbo_draws = NULL,
save_single_paths = NULL) {
self$init_alpha <- init_alpha
self$tol_obj <- tol_obj
self$tol_rel_obj <- tol_rel_obj
self$tol_grad <- tol_grad
self$tol_rel_grad <- tol_rel_grad
self$tol_param <- tol_param
self$history_size <- history_size
self$num_draws <- num_draws
self$num_paths <- num_paths
self$max_lbfgs_iters <- max_lbfgs_iters
self$num_elbo_draws <- num_elbo_draws
self$save_single_paths <- save_single_paths
invisible(self)
},

validate = function(num_procs) {
validate_pathfinder_args(self)
},

# Compose arguments to CmdStan command for pathfinder-specific
# non-default arguments. Works the same way as compose for sampler args,
# but `idx` (multiple pathfinders are handled in cmdstan)
compose = function(idx = NULL, args = NULL) {
.make_arg <- function(arg_name) {
compose_arg(self, arg_name, idx = NULL)
}
new_args <- list(
"method=pathfinder",
.make_arg("init_alpha"),
.make_arg("tol_obj"),
.make_arg("tol_rel_obj"),
.make_arg("tol_grad"),
.make_arg("tol_rel_grad"),
.make_arg("tol_param"),
.make_arg("history_size"),
.make_arg("num_draws"),
.make_arg("num_paths"),
.make_arg("max_lbfgs_iters"),
.make_arg("num_elbo_draws"),
.make_arg("save_single_paths")
)
new_args <- do.call(c, new_args)
c(args, new_args)
}
)
)

# DiagnoseArgs -------------------------------------------------------------

DiagnoseArgs <- R6::R6Class(
Expand Down Expand Up @@ -854,6 +926,52 @@ validate_variational_args <- function(self) {
invisible(TRUE)
}

#' Validate arguments for pathfinder inference
#' @noRd
#' @param self A `PathfinderArgs` object.
#' @return `TRUE` invisibly unless an error is thrown.
validate_pathfinder_args <- function(self) {

checkmate::assert_integerish(self$max_lbfgs_iters, lower = 1, null.ok = TRUE, len = 1)
if (!is.null(self$max_lbfgs_iters)) {
self$iter <- as.integer(self$max_lbfgs_iters)
}
checkmate::assert_integerish(self$num_paths, lower = 1, null.ok = TRUE, len = 1)
if (!is.null(self$num_paths)) {
self$num_paths <- as.integer(self$num_paths)
}
checkmate::assert_integerish(self$num_draws, lower = 1, null.ok = TRUE, len = 1)
if (!is.null(self$num_draws)) {
self$num_draws <- as.integer(self$num_draws)
}
checkmate::assert_integerish(self$num_elbo_draws, lower = 1, null.ok = TRUE, len = 1)
if (!is.null(self$num_elbo_draws)) {
self$num_elbo_draws <- as.integer(self$num_elbo_draws)
}
if (!is.null(self$save_single_paths) && is.logical(self$save_single_paths)) {
self$save_single_paths = as.integer(self$save_single_paths)
}
checkmate::assert_integerish(self$save_single_paths, null.ok = TRUE,
lower = 0, upper = 1, len = 1)
if (!is.null(self$save_single_paths)) {
self$save_single_paths <- 0
}


# check args only available for lbfgs and bfgs
bfgs_args <- c("init_alpha", "tol_obj", "tol_rel_obj", "tol_grad", "tol_rel_grad", "tol_param")
for (arg in bfgs_args) {
checkmate::assert_number(self[[arg]], .var.name = arg, lower = 0, null.ok = TRUE)
}

if (!is.null(self$history_size)) {
checkmate::assert_integerish(self$history_size, lower = 1, len = 1, null.ok = FALSE)
self$history_size <- as.integer(self$history_size)
}

invisible(TRUE)
}


# Validation helpers ------------------------------------------------------

Expand Down
40 changes: 38 additions & 2 deletions R/csv.R
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ read_cmdstan_csv <- function(files,
"\""
)
} else {
fread_cmd <- paste0("grep -v '^#' --color=never '", output_file, "'")
fread_cmd <- paste0("grep -v '^#' --color=never '", path.expand(output_file), "'")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What was this change needed for?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I'll revert this change

}
if (length(sampler_diagnostics) > 0) {
post_warmup_sd_id <- length(post_warmup_sampler_diagnostics) + 1
Expand All @@ -280,6 +280,9 @@ read_cmdstan_csv <- function(files,
if (length(variables) > 0) {
draws_list_id <- length(draws) + 1
warmup_draws_list_id <- length(warmup_draws) + 1
if (metadata$method == "pathfinder") {
variables = union(metadata$sampler_diagnostics, metadata$variables)
}
suppressWarnings(
draws[[draws_list_id]] <- data.table::fread(
cmd = fread_cmd,
Expand Down Expand Up @@ -445,6 +448,21 @@ read_cmdstan_csv <- function(files,
metadata = metadata,
generated_quantities = draws
)
} else if (metadata$method == "pathfinder") {
if (is.null(format)) {
format <- "draws_matrix"
}
as_draws_format <- as_draws_format_fun(format)
if (length(draws) == 0) {
pathfinder_draws <- NULL
} else {
pathfinder_draws <- do.call(as_draws_format, list(draws[[1]][-1, colnames(draws[[1]]), drop = FALSE]))
posterior::variables(pathfinder_draws) <- repaired_variables
}
list(
metadata = metadata,
draws = pathfinder_draws
)
}
}

Expand Down Expand Up @@ -477,6 +495,7 @@ as_cmdstan_fit <- function(files, check_diagnostics = TRUE, format = getOption("
"sample" = CmdStanMCMC_CSV$new(csv_contents, files, check_diagnostics),
"optimize" = CmdStanMLE_CSV$new(csv_contents, files),
"variational" = CmdStanVB_CSV$new(csv_contents, files),
"pathfinder" = CmdStanPathfinder_CSV$new(csv_contents, files),
"laplace" = CmdStanLaplace_CSV$new(csv_contents, files)
)
}
Expand Down Expand Up @@ -576,6 +595,23 @@ CmdStanVB_CSV <- R6::R6Class(
private = list(output_files_ = NULL)
)

CmdStanPathfinder_CSV <- R6::R6Class(
classname = "CmdStanPathfinder_CSV",
inherit = CmdStanPathfinder,
public = list(
initialize = function(csv_contents, files) {
private$output_files_ <- files
private$draws_ <- csv_contents$draws
private$metadata_ <- csv_contents$metadata
},
output_files = function(...) {
private$output_files_
}
),
private = list(output_files_ = NULL)
)


# these methods are unavailable because there's no CmdStanRun object
unavailable_methods_CmdStanFit_CSV <- c(
"cmdstan_diagnose", "cmdstan_summary",
Expand Down Expand Up @@ -642,7 +678,7 @@ read_csv_metadata <- function(csv_file) {
"\""
)
} else {
fread_cmd <- paste0("grep '^[#a-zA-Z]' --color=never '", csv_file, "'")
fread_cmd <- paste0("grep '^[#a-zA-Z]' --color=never '", path.expand(csv_file), "'")
}
suppressWarnings(
metadata <- data.table::fread(
Expand Down
2 changes: 1 addition & 1 deletion R/example.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
#'
cmdstanr_example <-
function(example = c("logistic", "schools", "schools_ncp"),
method = c("sample", "optimize", "laplace", "variational", "diagnose"),
method = c("sample", "optimize", "laplace", "variational", "pathfinder", "diagnose"),
...,
quiet = TRUE,
force_recompile = getOption("cmdstanr_force_recompile", default = FALSE)) {
Expand Down
80 changes: 80 additions & 0 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -2024,6 +2024,80 @@ CmdStanVB <- R6::R6Class(
)
CmdStanVB$set("public", name = "lp_approx", value = lp_approx)

# CmdStanPathfinder ---------------------------------------------------------------
#' CmdStanPathfinder objects
#'
#' @name CmdStanPathfinder
#' @family fitted model objects
#' @template seealso-docs
#'
#' @description A `CmdStanPathfinder` object is the fitted model object returned by the
#' [`$pathfinder()`][model-method-pathfinder] method of a
#' [`CmdStanModel`] object.
#'
#' @section Methods: `CmdStanPathfinder` objects have the following associated methods,
#' all of which have their own (linked) documentation pages.
#'
#' ## Extract contents of fitted model object
#'
#' |**Method**|**Description**|
#' |:----------|:---------------|
#' [`$draws()`][fit-method-draws] | Return approximate posterior draws as a [`draws_matrix`][posterior::draws_matrix]. |
#' [`$lp()`][fit-method-lp] | Return the total log probability density (`target`) computed in the model block of the Stan program. |
#' [`$lp_approx()`][fit-method-lp] | Return the log density of the approximation to the posterior. |
#' [`$init()`][fit-method-init] | Return user-specified initial values. |
#' [`$metadata()`][fit-method-metadata] | Return a list of metadata gathered from the CmdStan CSV files. |
#' [`$code()`][fit-method-code] | Return Stan code as a character vector. |
#'
#' ## Summarize inferences
#'
#' |**Method**|**Description**|
#' |:----------|:---------------|
#' [`$summary()`][fit-method-summary] | Run [`posterior::summarise_draws()`][posterior::draws_summary]. |
#' [`$cmdstan_summary()`][fit-method-cmdstan_summary] | Run and print CmdStan's `bin/stansummary`. |
#'
#' ## Save fitted model object and temporary files
#'
#' |**Method**|**Description**|
#' |:----------|:---------------|
#' [`$save_object()`][fit-method-save_object] | Save fitted model object to a file. |
#' [`$save_output_files()`][fit-method-save_output_files] | Save output CSV files to a specified location. |
#' [`$save_data_file()`][fit-method-save_data_file] | Save JSON data file to a specified location. |
#' [`$save_latent_dynamics_files()`][fit-method-save_latent_dynamics_files] | Save diagnostic CSV files to a specified location. |
#'
#' ## Report run times, console output, return codes
#'
#' |**Method**|**Description**|
#' |:----------|:---------------|
#' [`$time()`][fit-method-time] | Report the total run time. |
#' [`$output()`][fit-method-output] | Pretty print the output that was printed to the console. |
#' [`$return_codes()`][fit-method-return_codes] | Return the return codes from the CmdStan runs. |
#'
CmdStanPathfinder <- R6::R6Class(
classname = "CmdStanPathfinder",
inherit = CmdStanFit,
public = list(),
private = list(
# inherits draws_ and metadata_ slots from CmdStanFit
read_csv_ = function(format = getOption("cmdstanr_draws_format", "draws_matrix")) {
if (!length(self$output_files(include_failed = FALSE))) {
stop("Pathfinder failed. Unable to retrieve the draws.", call. = FALSE)
}
csv_contents <- read_cmdstan_csv(self$output_files(), format = format)
private$draws_ <- csv_contents$draws
private$metadata_ <- csv_contents$metadata
invisible(self)
}
)
)

#' @rdname fit-method-lp
lp_approx <- function() {
as.numeric(self$draws()[, "lp_approx__"])
}
CmdStanPathfinder$set("public", name = "lp_approx", value = lp_approx)



# CmdStanGQ ---------------------------------------------------------------
#' CmdStanGQ objects
Expand Down Expand Up @@ -2290,3 +2364,9 @@ as_draws.CmdStanVB <- function(x, ...) {
as_draws.CmdStanGQ <- function(x, ...) {
x$draws(...)
}

#' @rdname as_draws.CmdStanMCMC
#' @export
as_draws.CmdStanPathfinder <- function(x, ...) {
x$draws(...)
}
Loading
Loading