Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

laplace method #800

Merged
merged 28 commits into from
Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
9c2643c
initial attempt at laplace method
jgabry Jul 30, 2023
0b2a677
Update model.R
jgabry Jul 30, 2023
9ea8fd2
Update model-method-laplace.Rd
jgabry Jul 30, 2023
2528523
Merge branch 'master' into laplace-sample
jgabry Jul 30, 2023
04cfb4a
Delete fit-temp.rds
jgabry Jul 30, 2023
3558eda
fix link to optimize method doc
jgabry Jul 30, 2023
e821436
Update model.R
jgabry Jul 30, 2023
64501a0
tests for laplace CmdStanModel method
jgabry Jul 30, 2023
b66f4d1
more tests for laplace method
jgabry Jul 31, 2023
e229fc4
fix doc
jgabry Jul 31, 2023
7984814
a few more tests
jgabry Jul 31, 2023
aa286c3
Update _pkgdown.yml
jgabry Jul 31, 2023
513444c
fix r cmd check warning
jgabry Jul 31, 2023
0c0ce6f
Merge branch 'master' into laplace-sample
jgabry Aug 2, 2023
101810e
change `output_samples` in variational to `draws` for consistency
jgabry Aug 2, 2023
54ec3ad
add laplace section to vignette
jgabry Aug 2, 2023
c9d8fc2
fix vignette error
jgabry Aug 3, 2023
3c3a333
fix failing test
jgabry Aug 3, 2023
18e92f9
Merge branch 'master' into laplace-sample
jgabry Aug 16, 2023
4d172ae
update doc with Aki's suggestion
jgabry Aug 22, 2023
d8eab9e
Merge branch 'master' into laplace-sample
jgabry Aug 25, 2023
5ae78ca
Debug on WSL: Turn off running vignette so unit tests run
jgabry Aug 25, 2023
f447e92
Merge branch 'master' into laplace-sample
jgabry Aug 25, 2023
e13fb1a
undo turning off vignette
jgabry Sep 2, 2023
467054f
Merge branch 'master' into laplace-sample
jgabry Sep 15, 2023
5c2dbbd
Merge branch 'master' into laplace-sample
andrjohns Sep 22, 2023
620458f
Update file path for WSL
andrjohns Sep 25, 2023
1a0a97d
Fix non-WSL asserts
andrjohns Sep 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Generated by roxygen2: do not edit by hand

S3method(as_draws,CmdStanGQ)
S3method(as_draws,CmdStanLaplace)
S3method(as_draws,CmdStanMCMC)
S3method(as_draws,CmdStanMLE)
S3method(as_draws,CmdStanVB)
Expand Down
72 changes: 71 additions & 1 deletion R/args.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#'
#' * `SampleArgs`: stores arguments specific to `method=sample`.
#' * `OptimizeArgs`: stores arguments specific to `method=optimize`.
#' * `LaplaceArgs`: stores arguments specific to `method=laplace`.
#' * `VariationalArgs`: stores arguments specific to `method=variational`
#' * `GenerateQuantitiesArgs`: stores arguments specific to `method=generate_quantities`
#' * `DiagnoseArgs`: stores arguments specific to `method=diagnose`
Expand Down Expand Up @@ -427,6 +428,52 @@ OptimizeArgs <- R6::R6Class(
)


# LaplaceArgs -------------------------------------------------------------

LaplaceArgs <- R6::R6Class(
"LaplaceArgs",
lock_objects = FALSE,
public = list(
method = "laplace",
initialize = function(mode = NULL,
draws = NULL,
jacobian = TRUE) {
checkmate::assert_r6(mode, classes = "CmdStanMLE")
self$mode_object <- mode # keep the CmdStanMLE for later use (can be returned by CmdStanLaplace$mode())
# mode <- file path to pass to CmdStan
# This needs to be a path that can be accessed within WSL
# since the files are used by CmdStan, not R
self$mode <- wsl_safe_path(self$mode_object$output_files())
self$jacobian <- jacobian
self$draws <- draws
invisible(self)
},
validate = function(num_procs) {
validate_laplace_args(self)
invisible(self)
},

# Compose arguments to CmdStan command for laplace-specific
# non-default arguments. Works the same way as compose for sampler args,
# but `idx` is ignored (no multiple chains for optimize or variational)
compose = function(idx = NULL, args = NULL) {
.make_arg <- function(arg_name) {
compose_arg(self, arg_name, idx = NULL)
}
new_args <- list(
"method=laplace",
.make_arg("mode"),
.make_arg("draws"),
.make_arg("jacobian")
)
new_args <- do.call(c, new_args)
c(args, new_args)
}
)
)



# VariationalArgs ---------------------------------------------------------

VariationalArgs <- R6::R6Class(
Expand Down Expand Up @@ -712,6 +759,29 @@ validate_optimize_args <- function(self) {
invisible(TRUE)
}

#' Validate arguments for laplace
#' @noRd
#' @param self A `LaplaceArgs` object.
#' @return `TRUE` invisibly unless an error is thrown.
validate_laplace_args <- function(self) {
assert_file_exists(self$mode, extension = "csv")
checkmate::assert_integerish(self$draws, lower = 1, null.ok = TRUE, len = 1)
if (!is.null(self$draws)) {
self$draws <- as.integer(self$draws)
}
checkmate::assert_flag(self$jacobian, null.ok = FALSE)
if (self$mode_object$metadata()$jacobian != self$jacobian) {
stop(
"'jacobian' argument to optimize and laplace must match!\n",
"laplace was called with jacobian=", self$jacobian, "\n",
"optimize was run with jacobian=", as.logical(self$mode_object$metadata()$jacobian),
call. = FALSE
)
}
self$jacobian <- as.integer(self$jacobian)
invisible(TRUE)
}

#' Validate arguments for standalone generated quantities
#' @noRd
#' @param self A `GenerateQuantitiesArgs` object.
Expand Down Expand Up @@ -764,7 +834,7 @@ validate_variational_args <- function(self) {
self$eval_elbo <- as.integer(self$eval_elbo)
}
checkmate::assert_integerish(self$output_samples, null.ok = TRUE,
lower = 1, len = 1)
lower = 1, len = 1, .var.name = "draws")
if (!is.null(self$output_samples)) {
self$output_samples <- as.integer(self$output_samples)
}
Expand Down
57 changes: 52 additions & 5 deletions R/csv.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#' and memory for models with many parameters.
#'
#' @return
#' `as_cmdstan_fit()` returns a [CmdStanMCMC], [CmdStanMLE], or
#' `as_cmdstan_fit()` returns a [CmdStanMCMC], [CmdStanMLE], [CmdStanLaplace] or
#' [CmdStanVB] object. Some methods typically defined for those objects will not
#' work (e.g. `save_data_file()`) but the important methods like `$summary()`,
#' `$draws()`, `$sampler_diagnostics()` and others will work fine.
Expand Down Expand Up @@ -67,7 +67,8 @@
#'
#' * `point_estimates`: Point estimates for the model parameters.
#'
#' For [variational inference][model-method-variational] the returned list also
#' For [laplace][model-method-laplace] and
#' [variational inference][model-method-variational] the returned list also
#' includes the following components:
#'
#' * `draws`: A [`draws_matrix`][posterior::draws_matrix] (or different format
Expand Down Expand Up @@ -307,6 +308,11 @@ read_cmdstan_csv <- function(files,
repaired_variables <- repaired_variables[repaired_variables != "lp__"]
repaired_variables <- gsub("log_p__", "lp__", repaired_variables)
repaired_variables <- gsub("log_g__", "lp_approx__", repaired_variables)
} else if (metadata$method == "laplace") {
metadata$variables <- gsub("log_p__", "lp__", metadata$variables)
metadata$variables <- gsub("log_q__", "lp_approx__", metadata$variables)
repaired_variables <- gsub("log_p__", "lp__", repaired_variables)
repaired_variables <- gsub("log_q__", "lp_approx__", repaired_variables)
}
model_param_dims <- variable_dims(metadata$variables)
metadata$stan_variable_sizes <- model_param_dims
Expand Down Expand Up @@ -385,6 +391,29 @@ read_cmdstan_csv <- function(files,
metadata = metadata,
draws = variational_draws
)
} else if (metadata$method == "laplace") {
if (is.null(format)) {
format <- "draws_matrix"
}
as_draws_format <- as_draws_format_fun(format)
if (length(draws) == 0) {
laplace_draws <- NULL
} else {
laplace_draws <- do.call(as_draws_format, list(draws[[1]]))
}
if (!is.null(laplace_draws)) {
if ("log_p__" %in% posterior::variables(laplace_draws)) {
laplace_draws <- posterior::rename_variables(laplace_draws, lp__ = "log_p__")
}
if ("log_q__" %in% posterior::variables(laplace_draws)) {
laplace_draws <- posterior::rename_variables(laplace_draws, lp_approx__ = "log_q__")
}
posterior::variables(laplace_draws) <- repaired_variables
}
list(
metadata = metadata,
draws = laplace_draws
)
} else if (metadata$method == "optimize") {
if (is.null(format)) {
format <- "draws_matrix"
Expand Down Expand Up @@ -447,7 +476,8 @@ as_cmdstan_fit <- function(files, check_diagnostics = TRUE, format = getOption("
csv_contents$metadata$method,
"sample" = CmdStanMCMC_CSV$new(csv_contents, files, check_diagnostics),
"optimize" = CmdStanMLE_CSV$new(csv_contents, files),
"variational" = CmdStanVB_CSV$new(csv_contents, files)
"variational" = CmdStanVB_CSV$new(csv_contents, files),
"laplace" = CmdStanLaplace_CSV$new(csv_contents, files)
)
}

Expand Down Expand Up @@ -513,6 +543,22 @@ CmdStanMLE_CSV <- R6::R6Class(
),
private = list(output_files_ = NULL)
)
CmdStanLaplace_CSV <- R6::R6Class(
classname = "CmdStanLaplace_CSV",
inherit = CmdStanLaplace,
public = list(
initialize = function(csv_contents, files) {
private$output_files_ <- files
private$draws_ <- csv_contents$draws
private$metadata_ <- csv_contents$metadata
invisible(self)
},
output_files = function(...) {
private$output_files_
}
),
private = list(output_files_ = NULL)
)
CmdStanVB_CSV <- R6::R6Class(
classname = "CmdStanVB_CSV",
inherit = CmdStanVB,
Expand Down Expand Up @@ -554,6 +600,7 @@ for (method in unavailable_methods_CmdStanFit_CSV) {
}
CmdStanMLE_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV)
CmdStanVB_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV)
CmdStanLaplace_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV)
}


Expand Down Expand Up @@ -616,7 +663,7 @@ read_csv_metadata <- function(csv_file) {
all_names <- strsplit(line, ",")[[1]]
if (all(csv_file_info$algorithm != "fixed_param")) {
csv_file_info[["sampler_diagnostics"]] <- all_names[endsWith(all_names, "__")]
csv_file_info[["sampler_diagnostics"]] <- csv_file_info[["sampler_diagnostics"]][!(csv_file_info[["sampler_diagnostics"]] %in% c("lp__", "log_p__", "log_g__"))]
csv_file_info[["sampler_diagnostics"]] <- csv_file_info[["sampler_diagnostics"]][!(csv_file_info[["sampler_diagnostics"]] %in% c("lp__", "log_p__", "log_g__", "log_q__"))]
csv_file_info[["variables"]] <- all_names[!(all_names %in% csv_file_info[["sampler_diagnostics"]])]
} else {
csv_file_info[["variables"]] <- all_names[!endsWith(all_names, "__")]
Expand Down Expand Up @@ -719,7 +766,7 @@ read_csv_metadata <- function(csv_file) {
csv_file_info$step_size <- csv_file_info$stepsize
csv_file_info$iter_warmup <- csv_file_info$num_warmup
csv_file_info$iter_sampling <- csv_file_info$num_samples
if (csv_file_info$method == "variational" || csv_file_info$method == "optimize") {
if (csv_file_info$method %in% c("variational", "optimize", "laplace")) {
csv_file_info$threads <- csv_file_info$num_threads
} else {
csv_file_info$threads_per_chain <- csv_file_info$num_threads
Expand Down
2 changes: 1 addition & 1 deletion R/example.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
#'
cmdstanr_example <-
function(example = c("logistic", "schools", "schools_ncp"),
method = c("sample", "optimize", "variational", "diagnose"),
method = c("sample", "optimize", "laplace", "variational", "diagnose"),
...,
quiet = TRUE,
force_recompile = getOption("cmdstanr_force_recompile", default = FALSE)) {
Expand Down
Loading