Skip to content

Commit

Permalink
Merge pull request #800 from stan-dev/laplace-sample
Browse files Browse the repository at this point in the history
laplace method
  • Loading branch information
jgabry authored Sep 25, 2023
2 parents 7d3d6fa + 1a0a97d commit a63e418
Show file tree
Hide file tree
Showing 46 changed files with 2,246 additions and 141 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Generated by roxygen2: do not edit by hand

S3method(as_draws,CmdStanGQ)
S3method(as_draws,CmdStanLaplace)
S3method(as_draws,CmdStanMCMC)
S3method(as_draws,CmdStanMLE)
S3method(as_draws,CmdStanVB)
Expand Down
72 changes: 71 additions & 1 deletion R/args.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#'
#' * `SampleArgs`: stores arguments specific to `method=sample`.
#' * `OptimizeArgs`: stores arguments specific to `method=optimize`.
#' * `LaplaceArgs`: stores arguments specific to `method=laplace`.
#' * `VariationalArgs`: stores arguments specific to `method=variational`
#' * `GenerateQuantitiesArgs`: stores arguments specific to `method=generate_quantities`
#' * `DiagnoseArgs`: stores arguments specific to `method=diagnose`
Expand Down Expand Up @@ -427,6 +428,52 @@ OptimizeArgs <- R6::R6Class(
)


# LaplaceArgs -------------------------------------------------------------

LaplaceArgs <- R6::R6Class(
  "LaplaceArgs",
  lock_objects = FALSE,
  public = list(
    method = "laplace",

    # Store the laplace-specific arguments.
    #
    # `mode` must be a `CmdStanMLE` R6 object. The object itself is kept in
    # `self$mode_object` (so it can be handed back later, e.g. by
    # CmdStanLaplace$mode()), while `self$mode` holds the path of its output
    # file. That path is what gets passed to CmdStan, so it is converted with
    # `wsl_safe_path()` because CmdStan (not R) has to be able to read it.
    initialize = function(mode = NULL,
                          draws = NULL,
                          jacobian = TRUE) {
      checkmate::assert_r6(mode, classes = "CmdStanMLE")
      self$mode_object <- mode
      self$mode <- wsl_safe_path(mode$output_files())
      self$draws <- draws
      self$jacobian <- jacobian
      invisible(self)
    },

    # Run the laplace argument checks. `num_procs` is accepted but not used.
    validate = function(num_procs) {
      validate_laplace_args(self)
      invisible(self)
    },

    # Append the laplace-specific command line arguments to `args`.
    # `idx` is accepted for signature compatibility but is ignored: every
    # argument is composed with `idx = NULL`.
    compose = function(idx = NULL, args = NULL) {
      method_flags <- lapply(
        c("mode", "draws", "jacobian"),
        function(field) compose_arg(self, field, idx = NULL)
      )
      laplace_args <- do.call(c, c(list("method=laplace"), method_flags))
      c(args, laplace_args)
    }
  )
)



# VariationalArgs ---------------------------------------------------------

VariationalArgs <- R6::R6Class(
Expand Down Expand Up @@ -712,6 +759,29 @@ validate_optimize_args <- function(self) {
invisible(TRUE)
}

#' Validate arguments for laplace
#'
#' Checks the `mode`, `draws` and `jacobian` fields of a `LaplaceArgs` object
#' and stores the integer-coerced `draws` and `jacobian` back on `self`.
#' Also verifies that the `jacobian` setting recorded in the metadata of the
#' mode object agrees with the one requested for laplace.
#'
#' @noRd
#' @param self A `LaplaceArgs` object.
#' @return `TRUE` invisibly unless an error is thrown.
validate_laplace_args <- function(self) {
  assert_file_exists(self$mode, extension = "csv")
  checkmate::assert_integerish(self$draws, lower = 1, null.ok = TRUE, len = 1)
  if (!is.null(self$draws)) {
    self$draws <- as.integer(self$draws)
  }
  checkmate::assert_flag(self$jacobian, null.ok = FALSE)

  # Read the metadata once and make sure it yields a usable scalar before
  # comparing: a missing (NULL) or NA entry would otherwise make the `if`
  # below fail with an uninformative "argument is of length zero" /
  # "missing value where TRUE/FALSE needed" error.
  mode_jacobian <- self$mode_object$metadata()$jacobian
  if (length(mode_jacobian) != 1 || is.na(mode_jacobian)) {
    stop(
      "Unable to determine the 'jacobian' setting that optimize was run with ",
      "from the metadata of the supplied mode.",
      call. = FALSE
    )
  }
  if (mode_jacobian != self$jacobian) {
    stop(
      "'jacobian' argument to optimize and laplace must match!\n",
      "laplace was called with jacobian=", self$jacobian, "\n",
      "optimize was run with jacobian=", as.logical(mode_jacobian),
      call. = FALSE
    )
  }

  # Converted only after the comparison and error message above, which use
  # the logical value.
  self$jacobian <- as.integer(self$jacobian)
  invisible(TRUE)
}

#' Validate arguments for standalone generated quantities
#' @noRd
#' @param self A `GenerateQuantitiesArgs` object.
Expand Down Expand Up @@ -764,7 +834,7 @@ validate_variational_args <- function(self) {
self$eval_elbo <- as.integer(self$eval_elbo)
}
checkmate::assert_integerish(self$output_samples, null.ok = TRUE,
lower = 1, len = 1)
lower = 1, len = 1, .var.name = "draws")
if (!is.null(self$output_samples)) {
self$output_samples <- as.integer(self$output_samples)
}
Expand Down
57 changes: 52 additions & 5 deletions R/csv.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#' and memory for models with many parameters.
#'
#' @return
#' `as_cmdstan_fit()` returns a [CmdStanMCMC], [CmdStanMLE], or
#' `as_cmdstan_fit()` returns a [CmdStanMCMC], [CmdStanMLE], [CmdStanLaplace], or
#' [CmdStanVB] object. Some methods typically defined for those objects will not
#' work (e.g. `save_data_file()`) but the important methods like `$summary()`,
#' `$draws()`, `$sampler_diagnostics()` and others will work fine.
Expand Down Expand Up @@ -67,7 +67,8 @@
#'
#' * `point_estimates`: Point estimates for the model parameters.
#'
#' For [variational inference][model-method-variational] the returned list also
#' For [laplace][model-method-laplace] and
#' [variational inference][model-method-variational] the returned list also
#' includes the following components:
#'
#' * `draws`: A [`draws_matrix`][posterior::draws_matrix] (or different format
Expand Down Expand Up @@ -307,6 +308,11 @@ read_cmdstan_csv <- function(files,
repaired_variables <- repaired_variables[repaired_variables != "lp__"]
repaired_variables <- gsub("log_p__", "lp__", repaired_variables)
repaired_variables <- gsub("log_g__", "lp_approx__", repaired_variables)
} else if (metadata$method == "laplace") {
metadata$variables <- gsub("log_p__", "lp__", metadata$variables)
metadata$variables <- gsub("log_q__", "lp_approx__", metadata$variables)
repaired_variables <- gsub("log_p__", "lp__", repaired_variables)
repaired_variables <- gsub("log_q__", "lp_approx__", repaired_variables)
}
model_param_dims <- variable_dims(metadata$variables)
metadata$stan_variable_sizes <- model_param_dims
Expand Down Expand Up @@ -385,6 +391,29 @@ read_cmdstan_csv <- function(files,
metadata = metadata,
draws = variational_draws
)
} else if (metadata$method == "laplace") {
if (is.null(format)) {
format <- "draws_matrix"
}
as_draws_format <- as_draws_format_fun(format)
if (length(draws) == 0) {
laplace_draws <- NULL
} else {
laplace_draws <- do.call(as_draws_format, list(draws[[1]]))
}
if (!is.null(laplace_draws)) {
if ("log_p__" %in% posterior::variables(laplace_draws)) {
laplace_draws <- posterior::rename_variables(laplace_draws, lp__ = "log_p__")
}
if ("log_q__" %in% posterior::variables(laplace_draws)) {
laplace_draws <- posterior::rename_variables(laplace_draws, lp_approx__ = "log_q__")
}
posterior::variables(laplace_draws) <- repaired_variables
}
list(
metadata = metadata,
draws = laplace_draws
)
} else if (metadata$method == "optimize") {
if (is.null(format)) {
format <- "draws_matrix"
Expand Down Expand Up @@ -447,7 +476,8 @@ as_cmdstan_fit <- function(files, check_diagnostics = TRUE, format = getOption("
csv_contents$metadata$method,
"sample" = CmdStanMCMC_CSV$new(csv_contents, files, check_diagnostics),
"optimize" = CmdStanMLE_CSV$new(csv_contents, files),
"variational" = CmdStanVB_CSV$new(csv_contents, files)
"variational" = CmdStanVB_CSV$new(csv_contents, files),
"laplace" = CmdStanLaplace_CSV$new(csv_contents, files)
)
}

Expand Down Expand Up @@ -513,6 +543,22 @@ CmdStanMLE_CSV <- R6::R6Class(
),
private = list(output_files_ = NULL)
)
# Variant of CmdStanLaplace that is populated from the already-parsed
# contents of CmdStan CSV files (see its use in `as_cmdstan_fit()`).
CmdStanLaplace_CSV <- R6::R6Class(
  classname = "CmdStanLaplace_CSV",
  inherit = CmdStanLaplace,
  private = list(
    # Paths of the CSV files this object was created from.
    output_files_ = NULL
  ),
  public = list(
    # `csv_contents` is expected to carry `metadata` and `draws` elements;
    # `files` are the CSV paths, stored as given.
    initialize = function(csv_contents, files) {
      private$metadata_ <- csv_contents$metadata
      private$draws_ <- csv_contents$draws
      private$output_files_ <- files
      invisible(self)
    },

    # Return the stored CSV paths; any arguments are ignored.
    output_files = function(...) {
      private$output_files_
    }
  )
)
CmdStanVB_CSV <- R6::R6Class(
classname = "CmdStanVB_CSV",
inherit = CmdStanVB,
Expand Down Expand Up @@ -554,6 +600,7 @@ for (method in unavailable_methods_CmdStanFit_CSV) {
}
CmdStanMLE_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV)
CmdStanVB_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV)
CmdStanLaplace_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV)
}


Expand Down Expand Up @@ -616,7 +663,7 @@ read_csv_metadata <- function(csv_file) {
all_names <- strsplit(line, ",")[[1]]
if (all(csv_file_info$algorithm != "fixed_param")) {
csv_file_info[["sampler_diagnostics"]] <- all_names[endsWith(all_names, "__")]
csv_file_info[["sampler_diagnostics"]] <- csv_file_info[["sampler_diagnostics"]][!(csv_file_info[["sampler_diagnostics"]] %in% c("lp__", "log_p__", "log_g__"))]
csv_file_info[["sampler_diagnostics"]] <- csv_file_info[["sampler_diagnostics"]][!(csv_file_info[["sampler_diagnostics"]] %in% c("lp__", "log_p__", "log_g__", "log_q__"))]
csv_file_info[["variables"]] <- all_names[!(all_names %in% csv_file_info[["sampler_diagnostics"]])]
} else {
csv_file_info[["variables"]] <- all_names[!endsWith(all_names, "__")]
Expand Down Expand Up @@ -719,7 +766,7 @@ read_csv_metadata <- function(csv_file) {
csv_file_info$step_size <- csv_file_info$stepsize
csv_file_info$iter_warmup <- csv_file_info$num_warmup
csv_file_info$iter_sampling <- csv_file_info$num_samples
if (csv_file_info$method == "variational" || csv_file_info$method == "optimize") {
if (csv_file_info$method %in% c("variational", "optimize", "laplace")) {
csv_file_info$threads <- csv_file_info$num_threads
} else {
csv_file_info$threads_per_chain <- csv_file_info$num_threads
Expand Down
2 changes: 1 addition & 1 deletion R/example.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
#'
cmdstanr_example <-
function(example = c("logistic", "schools", "schools_ncp"),
method = c("sample", "optimize", "variational", "diagnose"),
method = c("sample", "optimize", "laplace", "variational", "diagnose"),
...,
quiet = TRUE,
force_recompile = getOption("cmdstanr_force_recompile", default = FALSE)) {
Expand Down
Loading

0 comments on commit a63e418

Please sign in to comment.