diff --git a/NAMESPACE b/NAMESPACE index 832c10280..4ab6a168e 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,6 +1,7 @@ # Generated by roxygen2: do not edit by hand S3method(as_draws,CmdStanGQ) +S3method(as_draws,CmdStanLaplace) S3method(as_draws,CmdStanMCMC) S3method(as_draws,CmdStanMLE) S3method(as_draws,CmdStanVB) diff --git a/R/args.R b/R/args.R index 2d622e86c..31841f048 100644 --- a/R/args.R +++ b/R/args.R @@ -14,6 +14,7 @@ #' #' * `SampleArgs`: stores arguments specific to `method=sample`. #' * `OptimizeArgs`: stores arguments specific to `method=optimize`. +#' * `LaplaceArgs`: stores arguments specific to `method=laplace`. #' * `VariationalArgs`: stores arguments specific to `method=variational` #' * `PathfinderArgs`: stores arguments specific to `method=pathfinder` #' * `GenerateQuantitiesArgs`: stores arguments specific to `method=generate_quantities` @@ -433,6 +434,52 @@ OptimizeArgs <- R6::R6Class( ) +# LaplaceArgs ------------------------------------------------------------- + +LaplaceArgs <- R6::R6Class( + "LaplaceArgs", + lock_objects = FALSE, + public = list( + method = "laplace", + initialize = function(mode = NULL, + draws = NULL, + jacobian = TRUE) { + checkmate::assert_r6(mode, classes = "CmdStanMLE") + self$mode_object <- mode # keep the CmdStanMLE for later use (can be returned by CmdStanLaplace$mode()) + # mode <- file path to pass to CmdStan + # This needs to be a path that can be accessed within WSL + # since the files are used by CmdStan, not R + self$mode <- wsl_safe_path(self$mode_object$output_files()) + self$jacobian <- jacobian + self$draws <- draws + invisible(self) + }, + validate = function(num_procs) { + validate_laplace_args(self) + invisible(self) + }, + + # Compose arguments to CmdStan command for laplace-specific + # non-default arguments. Works the same way as compose for sampler args, + # but `idx` is ignored (no multiple chains for optimize or variational) + compose = function(idx = NULL, args = NULL) { + .make_arg <- function(arg_name) { + compose_arg(self, arg_name, idx = NULL) + } + new_args <- list( + "method=laplace", + .make_arg("mode"), + .make_arg("draws"), + .make_arg("jacobian") + ) + new_args <- do.call(c, new_args) + c(args, new_args) + } + ) +) + + + # VariationalArgs --------------------------------------------------------- VariationalArgs <- R6::R6Class( @@ -784,6 +831,29 @@ validate_optimize_args <- function(self) { invisible(TRUE) } +#' Validate arguments for laplace +#' @noRd +#' @param self A `LaplaceArgs` object. +#' @return `TRUE` invisibly unless an error is thrown. +validate_laplace_args <- function(self) { + assert_file_exists(self$mode, extension = "csv") + checkmate::assert_integerish(self$draws, lower = 1, null.ok = TRUE, len = 1) + if (!is.null(self$draws)) { + self$draws <- as.integer(self$draws) + } + checkmate::assert_flag(self$jacobian, null.ok = FALSE) + if (self$mode_object$metadata()$jacobian != self$jacobian) { + stop( + "'jacobian' argument to optimize and laplace must match!\n", + "laplace was called with jacobian=", self$jacobian, "\n", + "optimize was run with jacobian=", as.logical(self$mode_object$metadata()$jacobian), + call. = FALSE + ) + } + self$jacobian <- as.integer(self$jacobian) + invisible(TRUE) +} + #' Validate arguments for standalone generated quantities #' @noRd #' @param self A `GenerateQuantitiesArgs` object. @@ -836,7 +906,7 @@ validate_variational_args <- function(self) { self$eval_elbo <- as.integer(self$eval_elbo) } checkmate::assert_integerish(self$output_samples, null.ok = TRUE, - lower = 1, len = 1) + lower = 1, len = 1, .var.name = "draws") if (!is.null(self$output_samples)) { self$output_samples <- as.integer(self$output_samples) } diff --git a/R/csv.R b/R/csv.R index b5af34a9e..9bb2320be 100644 --- a/R/csv.R +++ b/R/csv.R @@ -27,7 +27,7 @@ #' and memory for models with many parameters. #' #' @return -#' `as_cmdstan_fit()` returns a [CmdStanMCMC], [CmdStanMLE], or +#' `as_cmdstan_fit()` returns a [CmdStanMCMC], [CmdStanMLE], [CmdStanLaplace] or #' [CmdStanVB] object. Some methods typically defined for those objects will not #' work (e.g. `save_data_file()`) but the important methods like `$summary()`, #' `$draws()`, `$sampler_diagnostics()` and others will work fine. @@ -67,7 +67,8 @@ #' #' * `point_estimates`: Point estimates for the model parameters. #' -#' For [variational inference][model-method-variational] the returned list also +#' For [laplace][model-method-laplace] and +#' [variational inference][model-method-variational] the returned list also #' includes the following components: #' #' * `draws`: A [`draws_matrix`][posterior::draws_matrix] (or different format @@ -310,6 +311,11 @@ read_cmdstan_csv <- function(files, repaired_variables <- repaired_variables[repaired_variables != "lp__"] repaired_variables <- gsub("log_p__", "lp__", repaired_variables) repaired_variables <- gsub("log_g__", "lp_approx__", repaired_variables) + } else if (metadata$method == "laplace") { + metadata$variables <- gsub("log_p__", "lp__", metadata$variables) + metadata$variables <- gsub("log_q__", "lp_approx__", metadata$variables) + repaired_variables <- gsub("log_p__", "lp__", repaired_variables) + repaired_variables <- gsub("log_q__", "lp_approx__", repaired_variables) } model_param_dims <- variable_dims(metadata$variables) metadata$stan_variable_sizes <- model_param_dims @@ -388,6 +394,29 @@ read_cmdstan_csv <- function(files, metadata = metadata, draws = variational_draws ) + } else if (metadata$method == "laplace") { + if (is.null(format)) { + format <- "draws_matrix" + } + as_draws_format <- as_draws_format_fun(format) + if (length(draws) == 0) { + laplace_draws <- NULL + } else { + laplace_draws <- do.call(as_draws_format, list(draws[[1]])) + } + if (!is.null(laplace_draws)) { + if ("log_p__" %in% posterior::variables(laplace_draws)) { + laplace_draws <- posterior::rename_variables(laplace_draws, lp__ = "log_p__") + } + if ("log_q__" %in% posterior::variables(laplace_draws)) { + laplace_draws <- posterior::rename_variables(laplace_draws, lp_approx__ = "log_q__") + } + posterior::variables(laplace_draws) <- repaired_variables + } + list( + metadata = metadata, + draws = laplace_draws + ) } else if (metadata$method == "optimize") { if (is.null(format)) { format <- "draws_matrix" @@ -466,7 +495,8 @@ as_cmdstan_fit <- function(files, check_diagnostics = TRUE, format = getOption(" "sample" = CmdStanMCMC_CSV$new(csv_contents, files, check_diagnostics), "optimize" = CmdStanMLE_CSV$new(csv_contents, files), "variational" = CmdStanVB_CSV$new(csv_contents, files), - "pathfinder" = CmdStanPathfinder_CSV$new(csv_contents, files) + "pathfinder" = CmdStanPathfinder_CSV$new(csv_contents, files), + "laplace" = CmdStanLaplace_CSV$new(csv_contents, files) ) } @@ -532,6 +562,22 @@ CmdStanMLE_CSV <- R6::R6Class( ), private = list(output_files_ = NULL) ) +CmdStanLaplace_CSV <- R6::R6Class( + classname = "CmdStanLaplace_CSV", + inherit = CmdStanLaplace, + public = list( + initialize = function(csv_contents, files) { + private$output_files_ <- files + private$draws_ <- csv_contents$draws + private$metadata_ <- csv_contents$metadata + invisible(self) + }, + output_files = function(...) { + private$output_files_ + } + ), + private = list(output_files_ = NULL) +) CmdStanVB_CSV <- R6::R6Class( classname = "CmdStanVB_CSV", inherit = CmdStanVB, @@ -590,6 +636,7 @@ for (method in unavailable_methods_CmdStanFit_CSV) { } CmdStanMLE_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV) CmdStanVB_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV) + CmdStanLaplace_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV) } @@ -652,7 +699,7 @@ read_csv_metadata <- function(csv_file) { all_names <- strsplit(line, ",")[[1]] if (all(csv_file_info$algorithm != "fixed_param")) { csv_file_info[["sampler_diagnostics"]] <- all_names[endsWith(all_names, "__")] - csv_file_info[["sampler_diagnostics"]] <- csv_file_info[["sampler_diagnostics"]][!(csv_file_info[["sampler_diagnostics"]] %in% c("lp__", "log_p__", "log_g__"))] + csv_file_info[["sampler_diagnostics"]] <- csv_file_info[["sampler_diagnostics"]][!(csv_file_info[["sampler_diagnostics"]] %in% c("lp__", "log_p__", "log_g__", "log_q__"))] csv_file_info[["variables"]] <- all_names[!(all_names %in% csv_file_info[["sampler_diagnostics"]])] } else { csv_file_info[["variables"]] <- all_names[!endsWith(all_names, "__")] @@ -755,7 +802,7 @@ read_csv_metadata <- function(csv_file) { csv_file_info$step_size <- csv_file_info$stepsize csv_file_info$iter_warmup <- csv_file_info$num_warmup csv_file_info$iter_sampling <- csv_file_info$num_samples - if (csv_file_info$method == "variational" || csv_file_info$method == "optimize") { + if (csv_file_info$method %in% c("variational", "optimize", "laplace")) { csv_file_info$threads <- csv_file_info$num_threads } else { csv_file_info$threads_per_chain <- csv_file_info$num_threads diff --git a/R/example.R b/R/example.R index 63dc9ed56..5bb98ae3e 100644 --- a/R/example.R +++ b/R/example.R @@ -50,7 +50,7 @@ #' cmdstanr_example <- function(example = c("logistic", "schools", "schools_ncp"), - method = c("sample", "optimize", "variational", "pathfinder", "diagnose"), + method = c("sample", "optimize", "laplace", "variational", "pathfinder", "diagnose"), ..., quiet = TRUE, force_recompile = getOption("cmdstanr_force_recompile", default = FALSE)) { diff --git a/R/fit.R b/R/fit.R index 577891c70..164e420ec 100644 --- a/R/fit.R +++ b/R/fit.R @@ -2,8 +2,9 @@ #' CmdStanFit superclass #' #' @noRd -#' @description CmdStanMCMC, CmdStanMLE, CmdStanVB, CmdStanGQ all share the -#' methods of the superclass CmdStanFit and also have their own unique methods. +#' @description CmdStanMCMC, CmdStanMLE, CmdStanLaplace, CmdStanVB, CmdStanGQ +#' all share the methods of the superclass CmdStanFit and also have their own +#' unique methods. #' CmdStanFit <- R6::R6Class( classname = "CmdStanFit", @@ -693,10 +694,12 @@ CmdStanFit$set("public", name = "constrain_variables", value = constrain_variabl #' @description The `$lp()` method extracts `lp__`, the total log probability #' (`target`) accumulated in the model block of the Stan program. For #' variational inference the log density of the variational approximation to -#' the posterior is also available via the `$lp_approx()` method. +#' the posterior is available via the `$lp_approx()` method. For +#' Laplace approximation the unnormalized density of the approximation to +#' the posterior is available via the `$lp_approx()` method. #' #' See the [Log Probability Increment vs. Sampling -#' Statement](https://mc-stan.org/docs/2_23/reference-manual/sampling-statements-section.html) +#' Statement](https://mc-stan.org/docs/reference-manual/sampling-statements.html) #' section of the Stan Reference Manual for details on when normalizing #' constants are dropped from log probability calculations. #' @@ -707,21 +710,24 @@ CmdStanFit$set("public", name = "constrain_variables", value = constrain_variabl #' evaluated at a posterior draw (which is on the constrained space). `lp__` is #' intended to diagnose sampling efficiency and evaluate approximations. #' -#' `lp_approx__` is the log density of the variational approximation to `lp__` -#' (also on the unconstrained space). It is exposed in the variational method -#' for performing the checks described in Yao et al. (2018) and implemented in -#' the \pkg{loo} package. +#' For variational inference `lp_approx__` is the log density of the variational +#' approximation to `lp__` (also on the unconstrained space). It is exposed in +#' the variational method for performing the checks described in Yao et al. +#' (2018) and implemented in the \pkg{loo} package. +#' +#' For Laplace approximation `lp_approx__` is the unnormalized density of the +#' Laplace approximation. It can be used to perform the same checks as in the +#' case of the variational method described in Yao et al. (2018). #' #' @return A numeric vector with length equal to the number of (post-warmup) -#' draws for MCMC and variational inference, and length equal to `1` for -#' optimization. +#' draws or length equal to `1` for optimization. #' #' @references #' Yao, Y., Vehtari, A., Simpson, D., and Gelman, A. (2018). Yes, but did it #' work?: Evaluating variational inference. *Proceedings of the 35th #' International Conference on Machine Learning*, PMLR 80:5581–5590. #' -#' @seealso [`CmdStanMCMC`], [`CmdStanMLE`], [`CmdStanVB`] +#' @seealso [`CmdStanMCMC`], [`CmdStanMLE`], [`CmdStanLaplace`], [`CmdStanVB`] #' #' @examples #' \dontrun{ @@ -742,6 +748,13 @@ lp <- function() { } CmdStanFit$set("public", name = "lp", value = lp) +# will be used by a subset of fit objects below +#' @rdname fit-method-lp +lp_approx <- function() { + as.numeric(self$draws()[, "lp_approx__"]) +} + + #' Compute a summary table of estimates and diagnostics #' #' @name fit-method-summary @@ -769,7 +782,7 @@ CmdStanFit$set("public", name = "lp", value = lp) #' The `$print()` method returns the fitted model object itself (invisibly), #' which is the standard behavior for print methods in \R. #' -#' @seealso [`CmdStanMCMC`], [`CmdStanMLE`], [`CmdStanVB`], [`CmdStanGQ`] +#' @seealso [`CmdStanMCMC`], [`CmdStanMLE`], [`CmdStanLaplace`], [`CmdStanVB`], [`CmdStanGQ`] #' #' @examples #' \dontrun{ @@ -1008,7 +1021,9 @@ CmdStanFit$set("public", name = "data_file", value = data_file) #' @aliases time #' @description Report the run time in seconds. For MCMC additional information #' is provided about the run times of individual chains and the warmup and -#' sampling phases. +#' sampling phases. For Laplace approximation the time only include the time +#' for drawing the approximate sample and does not include the time +#' taken to run the `$optimize()` method. #' #' @return #' A list with elements @@ -1025,11 +1040,16 @@ CmdStanFit$set("public", name = "data_file", value = data_file) #' fit_mcmc <- cmdstanr_example("logistic", method = "sample") #' fit_mcmc$time() #' -#' fit_mle <- cmdstanr_example("logistic", method = "optimize") -#' fit_mle$time() -#' #' fit_vb <- cmdstanr_example("logistic", method = "variational") #' fit_vb$time() +#' +#' fit_mle <- cmdstanr_example("logistic", method = "optimize", jacobian = TRUE) +#' fit_mle$time() +#' +#' # use fit_mle to draw samples from laplace approximation +#' fit_laplace <- cmdstanr_example("logistic", method = "laplace", mode = fit_mle) +#' fit_laplace$time() # just time for drawing sample not for running optimize +#' fit_laplace$time()$total + fit_mle$time()$total # total time #' } #' time <- function() { @@ -1851,6 +1871,76 @@ mle <- function(variables = NULL) { } CmdStanMLE$set("public", name = "mle", value = mle) +# CmdStanLaplace --------------------------------------------------------------- +#' CmdStanLaplace objects +#' +#' @name CmdStanLaplace +#' @family fitted model objects +#' @template seealso-docs +#' +#' @description A `CmdStanLaplace` object is the fitted model object returned by the +#' [`$laplace()`][model-method-laplace] method of a +#' [`CmdStanModel`] object. +#' +#' @section Methods: `CmdStanLaplace` objects have the following associated methods, +#' all of which have their own (linked) documentation pages. +#' +#' ## Extract contents of fitted model object +#' +#' |**Method**|**Description**| +#' |:----------|:---------------| +#' [`$draws()`][fit-method-draws] | Return approximate posterior draws as a [`draws_matrix`][posterior::draws_matrix]. | +#' `$mode()` | Return the mode as a [`CmdStanMLE`] object. | +#' [`$lp()`][fit-method-lp] | Return the total log probability density (`target`) computed in the model block of the Stan program. | +#' [`$lp_approx()`][fit-method-lp] | Return the log density of the approximation to the posterior. | +#' [`$init()`][fit-method-init] | Return user-specified initial values. | +#' [`$metadata()`][fit-method-metadata] | Return a list of metadata gathered from the CmdStan CSV files. | +#' [`$code()`][fit-method-code] | Return Stan code as a character vector. | +#' +#' ## Summarize inferences +#' +#' |**Method**|**Description**| +#' |:----------|:---------------| +#' [`$summary()`][fit-method-summary] | Run [`posterior::summarise_draws()`][posterior::draws_summary]. | +#' +#' ## Save fitted model object and temporary files +#' +#' |**Method**|**Description**| +#' |:----------|:---------------| +#' [`$save_object()`][fit-method-save_object] | Save fitted model object to a file. | +#' [`$save_output_files()`][fit-method-save_output_files] | Save output CSV files to a specified location. | +#' [`$save_data_file()`][fit-method-save_data_file] | Save JSON data file to a specified location. | +#' [`$save_latent_dynamics_files()`][fit-method-save_latent_dynamics_files] | Save diagnostic CSV files to a specified location. | +#' +#' ## Report run times, console output, return codes +#' +#' |**Method**|**Description**| +#' |:----------|:---------------| +#' [`$time()`][fit-method-time] | Report the run time of the Laplace sampling step. | +#' [`$output()`][fit-method-output] | Pretty print the output that was printed to the console. | +#' [`$return_codes()`][fit-method-return_codes] | Return the return codes from the CmdStan runs. | +#' +CmdStanLaplace <- R6::R6Class( + classname = "CmdStanLaplace", + inherit = CmdStanFit, + public = list( + mode = function() self$runset$args$method_args$mode_object + ), + private = list( + # inherits draws_ and metadata_ slots from CmdStanFit + read_csv_ = function(format = getOption("cmdstanr_draws_format", "draws_matrix")) { + if (!length(self$output_files(include_failed = FALSE))) { + stop("Laplace inference failed. Unable to retrieve the draws.", call. = FALSE) + } + csv_contents <- read_cmdstan_csv(self$output_files(), format = format) + private$draws_ <- csv_contents$draws + private$metadata_ <- csv_contents$metadata + invisible(self) + } + ) +) +CmdStanLaplace$set("public", name = "lp_approx", value = lp_approx) + # CmdStanVB --------------------------------------------------------------- #' CmdStanVB objects @@ -1932,11 +2022,6 @@ CmdStanVB <- R6::R6Class( } ) ) - -#' @rdname fit-method-lp -lp_approx <- function() { - as.numeric(self$draws()[, "lp_approx__"]) -} CmdStanVB$set("public", name = "lp_approx", value = lp_approx) # CmdStanPathfinder --------------------------------------------------------------- @@ -2262,6 +2347,12 @@ as_draws.CmdStanMLE <- function(x, ...) { x$draws(...) } +#' @rdname as_draws.CmdStanMCMC +#' @export +as_draws.CmdStanLaplace <- function(x, ...) { + x$draws(...) +} + #' @rdname as_draws.CmdStanMCMC #' @export as_draws.CmdStanVB <- function(x, ...) { diff --git a/R/model.R b/R/model.R index 8322abc76..9c163f455 100644 --- a/R/model.R +++ b/R/model.R @@ -69,6 +69,9 @@ #' # Use 'posterior' package for summaries #' fit_mcmc$summary() #' +#' # Check sampling diagnostics +#' fit_mcmc$diagnostic_summary() +#' #' # Get posterior draws #' draws <- fit_mcmc$draws() #' print(draws) @@ -79,13 +82,8 @@ #' # Plot posterior using bayesplot (ggplot2) #' mcmc_hist(fit_mcmc$draws("theta")) #' -#' # Call CmdStan's diagnose and stansummary utilities -#' fit_mcmc$cmdstan_diagnose() -#' fit_mcmc$cmdstan_summary() -#' #' # For models fit using MCMC, if you like working with RStan's stanfit objects #' # then you can create one with rstan::read_stan_csv() -#' #' # stanfit <- rstan::read_stan_csv(fit_mcmc$output_files()) #' #' @@ -93,13 +91,16 @@ #' # and also demonstrate specifying data as a path to a file instead of a list #' my_data_file <- file.path(cmdstan_path(), "examples/bernoulli/bernoulli.data.json") #' fit_optim <- mod$optimize(data = my_data_file, seed = 123) -#' #' fit_optim$summary() #' +#' # Run 'optimize' again with 'jacobian=TRUE' and then draw from laplace approximation +#' # to the posterior +#' fit_optim <- mod$optimize(data = my_data_file, jacobian = TRUE) +#' fit_laplace <- mod$laplace(data = my_data_file, mode = fit_optim, draws = 2000) +#' fit_laplace$summary() #' #' # Run 'variational' method to approximate the posterior (default is meanfield ADVI) #' fit_vb <- mod$variational(data = stan_data, seed = 123) -#' #' fit_vb$summary() #' #' # Plot approximate posterior using bayesplot @@ -1447,11 +1448,13 @@ CmdStanModel$set("public", name = "sample_mpi", value = sample_mpi) #' the CmdStan User's Guide. The default values can also be obtained by #' running `cmdstanr_example(method="optimize")$metadata()`. #' @param jacobian (logical) Whether or not to use the Jacobian adjustment for -#' constrained variables. By default this is `FALSE`, meaning optimization +#' constrained variables. For historical reasons, the default is `FALSE`, meaning optimization #' yields the (regularized) maximum likelihood estimate. Setting it to `TRUE` #' yields the maximum a posteriori estimate. See the #' [Maximum Likelihood Estimation](https://mc-stan.org/docs/cmdstan-guide/maximum-likelihood-estimation.html) #' section of the CmdStan User's Guide for more details. +#' For use later with [`$laplace()`][model-method-laplace] the `jacobian` +#' argument should typically be set to `TRUE`. #' @param init_alpha (positive real) The initial step size parameter. #' @param tol_obj (positive real) Convergence tolerance on changes in objective function value. #' @param tol_rel_obj (positive real) Convergence tolerance on relative changes in objective function value. @@ -1534,6 +1537,170 @@ optimize <- function(data = NULL, CmdStanModel$set("public", name = "optimize", value = optimize) +#' Run Stan's laplace algorithm +#' +#' @name model-method-laplace +#' @aliases laplace +#' @family CmdStanModel methods +#' +#' @description The `$laplace()` method of a [`CmdStanModel`] object produces a +#' sample from a normal approximation centered at the mode of a distribution +#' in the unconstrained space. If the mode is a maximum a posteriori (MAP) +#' estimate, the samples provide an estimate of the mean and standard +#' deviation of the posterior distribution. If the mode is a maximum +#' likelihood estimate (MLE), the sample provides an estimate of the standard +#' error of the likelihood. Whether the mode is the MAP or MLE depends on +#' the value of the `jacobian` argument when running optimization. See the +#' [Laplace Sampling](https://mc-stan.org/docs/cmdstan-guide/laplace-sampling.html) +#' section of the CmdStan User's Guide for more details. +#' +#' Any argument left as `NULL` will default to the default value used by the +#' installed version of CmdStan. See the +#' [CmdStan User’s Guide](https://mc-stan.org/docs/cmdstan-guide/) +#' for more details on the default arguments. +#' +#' @template model-common-args +#' @inheritParams model-method-optimize +#' @param save_latent_dynamics Ignored for this method. +#' @param mode (multiple options) The mode to center the approximation at. One +#' of the following: +#' * A [`CmdStanMLE`] object from a previous run of [`$optimize()`][model-method-optimize]. +#' * The path to a CmdStan CSV file from running optimization. +#' * `NULL`, in which case [$optimize()][model-method-optimize] will be run +#' with `jacobian=jacobian` (see the `jacobian` argument below). +#' +#' In all cases the total time reported by [`$time()`][fit-method-time] will be +#' the time of the Laplace sampling step only and does not include the time +#' taken to run the `$optimize()` method. +#' @param opt_args (named list) A named list of optional arguments to pass to +#' [$optimize()][model-method-optimize] if `mode=NULL`. +#' @param draws (positive integer) The number of draws to take. +#' @param jacobian (logical) Whether or not to enable the Jacobian adjustment +#' for constrained parameters. The default is `TRUE`. See the +#' [Laplace Sampling](https://mc-stan.org/docs/cmdstan-guide/laplace-sampling.html) +#' section of the CmdStan User's Guide for more details. If `mode` is not +#' `NULL` then the value of `jacobian` must match the value used when +#' optimization was originally run. If `mode` is `NULL` then the value of +#' `jacobian` specified here is used when running optimization. +#' +#' @return A [`CmdStanLaplace`] object. +#' +#' @template seealso-docs +#' @examples +#' \dontrun{ +#' file <- file.path(cmdstan_path(), "examples/bernoulli/bernoulli.stan") +#' mod <- cmdstan_model(file) +#' mod$print() +#' +#' stan_data <- list(N = 10, y = c(0,1,0,0,0,0,0,0,0,1)) +#' fit_mode <- mod$optimize(data = stan_data, jacobian = TRUE) +#' fit_laplace <- mod$laplace(data = stan_data, mode = fit_mode) +#' fit_laplace$summary() +#' +#' # if mode isn't specified optimize is run internally first +#' fit_laplace <- mod$laplace(data = stan_data) +#' fit_laplace$summary() +#' +#' # plot approximate posterior +#' bayesplot::mcmc_hist(fit_laplace$draws("theta")) +#' } +#' +#' +laplace <- function(data = NULL, + seed = NULL, + refresh = NULL, + init = NULL, + save_latent_dynamics = FALSE, + output_dir = NULL, + output_basename = NULL, + sig_figs = NULL, + threads = NULL, + opencl_ids = NULL, + mode = NULL, + opt_args = NULL, + jacobian = TRUE, # different default than for optimize! + draws = NULL) { + if (cmdstan_version() < "2.32") { + stop("This method is only available in cmdstan >= 2.32", call. = FALSE) + } + if (!is.null(mode) && !is.null(opt_args)) { + stop("Cannot specify both 'opt_args' and 'mode' arguments.", call. = FALSE) + } + procs <- CmdStanProcs$new( + num_procs = 1, + show_stdout_messages = (is.null(refresh) || refresh != 0), + threads_per_proc = assert_valid_threads(threads, self$cpp_options()) + ) + model_variables <- NULL + if (is_variables_method_supported(self)) { + model_variables <- self$variables() + } + + if (!is.null(mode)) { + if (inherits(mode, "CmdStanMLE")) { + cmdstan_mode <- mode + } else { + if (!(is.character(mode) && length(mode) == 1)) { + stop("If not NULL or a CmdStanMLE object then 'mode' must be a path to a CSV file.", call. = FALSE) + } + cmdstan_mode <- as_cmdstan_fit(mode) + } + } else { # mode = NULL, run optimize() + checkmate::assert_list(opt_args, any.missing = FALSE, names = "unique", null.ok = TRUE) + args <- list( + data = data, + seed = seed, + refresh = refresh, + init = init, + save_latent_dynamics = FALSE, + output_dir = output_dir, + output_basename = output_basename, + sig_figs = sig_figs, + threads = threads, + opencl_ids = opencl_ids, + jacobian = jacobian + ) + cmdstan_mode <- do.call(self$optimize, append(args, opt_args)) + if (cmdstan_mode$return_codes() != 0) { + stop( + "Optimization failed.\n", + "Consider supplying the 'mode' argument or additional optimizer args.", + call. = FALSE + ) + } + } + laplace_args <- LaplaceArgs$new( + mode = cmdstan_mode, + draws = draws, + jacobian = jacobian + ) + args <- CmdStanArgs$new( + method_args = laplace_args, + stan_file = self$stan_file(), + stan_code = suppressWarnings(self$code()), + model_methods_env = private$model_methods_env_, + standalone_env = self$functions, + model_name = self$model_name(), + exe_file = self$exe_file(), + proc_ids = 1, + data_file = process_data(data, model_variables), + save_latent_dynamics = FALSE, + seed = seed, + init = init, + refresh = refresh, + output_dir = output_dir, + output_basename = output_basename, + sig_figs = sig_figs, + opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()), + model_variables = model_variables + ) + runset <- CmdStanRun$new(args, procs) + runset$run_cmdstan() + CmdStanLaplace$new(runset) +} +CmdStanModel$set("public", name = "laplace", value = laplace) + + #' Run Stan's variational approximation algorithms #' #' @name model-method-variational @@ -1577,7 +1744,9 @@ CmdStanModel$set("public", name = "optimize", value = optimize) #' @param tol_rel_obj (positive real) Convergence tolerance on the relative norm #' of the objective. #' @param eval_elbo (positive integer) Evaluate ELBO every Nth iteration. -#' @param output_samples (positive integer) Number of approximate posterior +#' @param output_samples (positive integer) Use `draws` argument instead. +#' `output_samples` will be deprecated in the future. +#' @param draws (positive integer) Number of approximate posterior #' samples to draw and save. #' #' @return A [`CmdStanVB`] object. @@ -1604,7 +1773,8 @@ variational <- function(data = NULL, adapt_iter = NULL, tol_rel_obj = NULL, eval_elbo = NULL, - output_samples = NULL) { + output_samples = NULL, + draws = NULL) { procs <- CmdStanProcs$new( num_procs = 1, show_stdout_messages = (is.null(refresh) || refresh != 0), @@ -1624,7 +1794,7 @@ variational <- function(data = NULL, adapt_iter = adapt_iter, tol_rel_obj = tol_rel_obj, eval_elbo = eval_elbo, - output_samples = output_samples + output_samples = draws %||% output_samples ) args <- CmdStanArgs$new( method_args = variational_args, diff --git a/R/run.R b/R/run.R index 752eca095..16a0c97ba 100644 --- a/R/run.R +++ b/R/run.R @@ -229,6 +229,9 @@ CmdStanRun <- R6::R6Class( if (self$method() == "optimize") { stop("Not available for optimize method.", call. = FALSE) } + if (self$method() == "laplace") { + stop("Not available for laplace method.", call. = FALSE) + } if (self$method() == "generate_quantities") { stop("Not available for generate_quantities method.", call. = FALSE) } @@ -259,7 +262,7 @@ CmdStanRun <- R6::R6Class( }, time = function() { - if (self$method() %in% c("optimize", "variational", "pathfinder")) { + if (self$method() %in% c("laplace", "optimize", "variational", "pathfinder")) { time <- list(total = self$procs$total_time()) } else if (self$method() == "generate_quantities") { chain_time <- data.frame( @@ -511,7 +514,7 @@ CmdStanRun$set("private", name = "run_generate_quantities_", value = .run_genera procs$process_output(id) procs$process_error_output(id) successful_fit <- FALSE - if (self$method() == "optimize") { + if (self$method() %in% "optimize") { # QUESTION: should this include laplace? if (procs$proc_state(id = id) > 3) { successful_fit <- TRUE } @@ -532,6 +535,7 @@ CmdStanRun$set("private", name = "run_generate_quantities_", value = .run_genera procs$report_time() } CmdStanRun$set("private", name = "run_optimize_", value = .run_other) +CmdStanRun$set("private", name = "run_laplace_", value = .run_other) CmdStanRun$set("private", name = "run_variational_", value = .run_other) CmdStanRun$set("private", name = "run_pathfinder_", value = .run_other) diff --git a/R/utils.R b/R/utils.R index ec8189411..5f11bd6e0 100644 --- a/R/utils.R +++ b/R/utils.R @@ -600,7 +600,7 @@ wsl_tempdir <- function() { # network path as legitimate, and will always error. To avoid this we create a # new checking functions with WSL handling, and then pass these to # checkmate::makeAssertionFunction to replicate the existing assertion functionality -check_dir_exists <- function(dir, access = NULL) { +check_dir_exists <- function(dir, access = "") { if (os_is_wsl()) { if (!checkmate::qtest(dir, "S+")) { return("No directory provided.") @@ -616,7 +616,7 @@ check_dir_exists <- function(dir, access = NULL) { } } -check_file_exists <- function(files, access = NULL, ...) { +check_file_exists <- function(files, access = "", ...) { if (os_is_wsl()) { if (!checkmate::qtest(files, "S+")) { return("No file provided.") @@ -632,7 +632,7 @@ check_file_exists <- function(files, access = NULL, ...) { } } -.wsl_check_exists <- function(path, is_dir = TRUE, access = NULL) { +.wsl_check_exists <- function(path, is_dir = TRUE, access = "") { path_check <- processx::run( command = "wsl", args = c("ls", "-la", wsl_safe_path(path)), diff --git a/_pkgdown.yml b/_pkgdown.yml index e9133e951..d3786f7d4 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -100,6 +100,7 @@ reference: contents: - CmdStanMCMC - CmdStanMLE + - CmdStanLaplace - CmdStanVB - CmdStanGQ - CmdStanDiagnose diff --git a/man/CmdStanDiagnose.Rd b/man/CmdStanDiagnose.Rd index c9dd6c2b9..70930db17 100644 --- a/man/CmdStanDiagnose.Rd +++ b/man/CmdStanDiagnose.Rd @@ -42,6 +42,7 @@ The Stan and CmdStan documentation: Other fitted model objects: \code{\link{CmdStanGQ}}, +\code{\link{CmdStanLaplace}}, \code{\link{CmdStanMCMC}}, \code{\link{CmdStanMLE}}, \code{\link{CmdStanVB}} diff --git a/man/CmdStanGQ.Rd b/man/CmdStanGQ.Rd index 860820561..6ebe2d9c6 100644 --- a/man/CmdStanGQ.Rd +++ b/man/CmdStanGQ.Rd @@ -103,6 +103,7 @@ The Stan and CmdStan documentation: Other fitted model objects: \code{\link{CmdStanDiagnose}}, +\code{\link{CmdStanLaplace}}, \code{\link{CmdStanMCMC}}, \code{\link{CmdStanMLE}}, \code{\link{CmdStanVB}} diff --git a/man/CmdStanLaplace.Rd b/man/CmdStanLaplace.Rd new file mode 100644 index 000000000..50d782a63 --- /dev/null +++ b/man/CmdStanLaplace.Rd @@ -0,0 +1,72 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/fit.R +\name{CmdStanLaplace} +\alias{CmdStanLaplace} +\title{CmdStanLaplace objects} +\description{ +A \code{CmdStanLaplace} object is the fitted model object returned by the +\code{\link[=model-method-laplace]{$laplace()}} method of a +\code{\link{CmdStanModel}} object. +} +\section{Methods}{ + \code{CmdStanLaplace} objects have the following associated methods, +all of which have their own (linked) documentation pages. +\subsection{Extract contents of fitted model object}{\tabular{ll}{ + \strong{Method} \tab \strong{Description} \cr + \code{\link[=fit-method-draws]{$draws()}} \tab Return approximate posterior draws as a \code{\link[posterior:draws_matrix]{draws_matrix}}. \cr + \verb{$mode()} \tab Return the mode as a \code{\link{CmdStanMLE}} object. \cr + \code{\link[=fit-method-lp]{$lp()}} \tab Return the total log probability density (\code{target}) computed in the model block of the Stan program. \cr + \code{\link[=fit-method-lp]{$lp_approx()}} \tab Return the log density of the approximation to the posterior. \cr + \code{\link[=fit-method-init]{$init()}} \tab Return user-specified initial values. \cr + \code{\link[=fit-method-metadata]{$metadata()}} \tab Return a list of metadata gathered from the CmdStan CSV files. \cr + \code{\link[=fit-method-code]{$code()}} \tab Return Stan code as a character vector. \cr +} + +} + +\subsection{Summarize inferences}{\tabular{ll}{ + \strong{Method} \tab \strong{Description} \cr + \code{\link[=fit-method-summary]{$summary()}} \tab Run \code{\link[posterior:draws_summary]{posterior::summarise_draws()}}. \cr +} + +} + +\subsection{Save fitted model object and temporary files}{\tabular{ll}{ + \strong{Method} \tab \strong{Description} \cr + \code{\link[=fit-method-save_object]{$save_object()}} \tab Save fitted model object to a file. \cr + \code{\link[=fit-method-save_output_files]{$save_output_files()}} \tab Save output CSV files to a specified location. \cr + \code{\link[=fit-method-save_data_file]{$save_data_file()}} \tab Save JSON data file to a specified location. \cr + \code{\link[=fit-method-save_latent_dynamics_files]{$save_latent_dynamics_files()}} \tab Save diagnostic CSV files to a specified location. \cr +} + +} + +\subsection{Report run times, console output, return codes}{\tabular{ll}{ + \strong{Method} \tab \strong{Description} \cr + \code{\link[=fit-method-time]{$time()}} \tab Report the run time of the Laplace sampling step. \cr + \code{\link[=fit-method-output]{$output()}} \tab Pretty print the output that was printed to the console. \cr + \code{\link[=fit-method-return_codes]{$return_codes()}} \tab Return the return codes from the CmdStan runs. \cr +} + +} +} + +\seealso{ +The CmdStanR website +(\href{https://mc-stan.org/cmdstanr/}{mc-stan.org/cmdstanr}) for online +documentation and tutorials. + +The Stan and CmdStan documentation: +\itemize{ +\item Stan documentation: \href{https://mc-stan.org/users/documentation/}{mc-stan.org/users/documentation} +\item CmdStan User’s Guide: \href{https://mc-stan.org/docs/cmdstan-guide/}{mc-stan.org/docs/cmdstan-guide} +} + +Other fitted model objects: +\code{\link{CmdStanDiagnose}}, +\code{\link{CmdStanGQ}}, +\code{\link{CmdStanMCMC}}, +\code{\link{CmdStanMLE}}, +\code{\link{CmdStanVB}} +} +\concept{fitted model objects} diff --git a/man/CmdStanMCMC.Rd b/man/CmdStanMCMC.Rd index 218ee49c3..c74ff41f2 100644 --- a/man/CmdStanMCMC.Rd +++ b/man/CmdStanMCMC.Rd @@ -87,6 +87,7 @@ The Stan and CmdStan documentation: Other fitted model objects: \code{\link{CmdStanDiagnose}}, \code{\link{CmdStanGQ}}, +\code{\link{CmdStanLaplace}}, \code{\link{CmdStanMLE}}, \code{\link{CmdStanVB}} } diff --git a/man/CmdStanMLE.Rd b/man/CmdStanMLE.Rd index 01acae4d9..141e45aa6 100644 --- a/man/CmdStanMLE.Rd +++ b/man/CmdStanMLE.Rd @@ -77,6 +77,7 @@ The Stan and CmdStan documentation: Other fitted model objects: \code{\link{CmdStanDiagnose}}, \code{\link{CmdStanGQ}}, +\code{\link{CmdStanLaplace}}, \code{\link{CmdStanMCMC}}, \code{\link{CmdStanVB}} } diff --git a/man/CmdStanModel.Rd b/man/CmdStanModel.Rd index 2e8176aa7..2644784e2 100644 --- a/man/CmdStanModel.Rd +++ b/man/CmdStanModel.Rd @@ -87,6 +87,9 @@ fit_mcmc <- mod$sample( # Use 'posterior' package for summaries fit_mcmc$summary() +# Check sampling diagnostics +fit_mcmc$diagnostic_summary() + # Get posterior draws draws <- fit_mcmc$draws() print(draws) @@ -97,13 +100,8 @@ as_draws_df(draws) # Plot posterior using bayesplot (ggplot2) mcmc_hist(fit_mcmc$draws("theta")) -# Call CmdStan's diagnose and stansummary utilities -fit_mcmc$cmdstan_diagnose() -fit_mcmc$cmdstan_summary() - # For models fit using MCMC, if you like working with RStan's stanfit objects # then you can create one with rstan::read_stan_csv() - # stanfit <- rstan::read_stan_csv(fit_mcmc$output_files()) @@ -111,13 +109,16 @@ fit_mcmc$cmdstan_summary() # and also demonstrate specifying data as a path to a file instead of a list my_data_file <- file.path(cmdstan_path(), "examples/bernoulli/bernoulli.data.json") fit_optim <- mod$optimize(data = my_data_file, seed = 123) - fit_optim$summary() +# Run 'optimize' again with 'jacobian=TRUE' and then draw from laplace approximation +# to the posterior +fit_optim <- mod$optimize(data = my_data_file, jacobian = TRUE) +fit_laplace <- mod$laplace(data = my_data_file, mode = fit_optim, draws = 2000) +fit_laplace$summary() # Run 'variational' method to approximate the posterior (default is meanfield ADVI) fit_vb <- mod$variational(data = stan_data, seed = 123) - fit_vb$summary() # Plot approximate posterior using bayesplot diff --git a/man/CmdStanVB.Rd b/man/CmdStanVB.Rd index 4b4d53ada..9c8cd80b2 100644 --- a/man/CmdStanVB.Rd +++ b/man/CmdStanVB.Rd @@ -80,6 +80,7 @@ The Stan and CmdStan documentation: Other fitted model objects: \code{\link{CmdStanDiagnose}}, \code{\link{CmdStanGQ}}, +\code{\link{CmdStanLaplace}}, \code{\link{CmdStanMCMC}}, \code{\link{CmdStanMLE}} } diff --git a/man/as_draws.CmdStanMCMC.Rd b/man/as_draws.CmdStanMCMC.Rd index 13ea4f467..a63ae953f 100644 --- a/man/as_draws.CmdStanMCMC.Rd +++ b/man/as_draws.CmdStanMCMC.Rd @@ -4,6 +4,7 @@ \alias{as_draws.CmdStanMCMC} \alias{as_draws} \alias{as_draws.CmdStanMLE} +\alias{as_draws.CmdStanLaplace} \alias{as_draws.CmdStanVB} \alias{as_draws.CmdStanGQ} \title{Create a \code{draws} object from a CmdStanR fitted model object} @@ -12,6 +13,8 @@ \method{as_draws}{CmdStanMLE}(x, ...) +\method{as_draws}{CmdStanLaplace}(x, ...) + \method{as_draws}{CmdStanVB}(x, ...) \method{as_draws}{CmdStanGQ}(x, ...) diff --git a/man/cmdstan_model.Rd b/man/cmdstan_model.Rd index 1f48d1c86..6e9408dd1 100644 --- a/man/cmdstan_model.Rd +++ b/man/cmdstan_model.Rd @@ -74,6 +74,9 @@ fit_mcmc <- mod$sample( # Use 'posterior' package for summaries fit_mcmc$summary() +# Check sampling diagnostics +fit_mcmc$diagnostic_summary() + # Get posterior draws draws <- fit_mcmc$draws() print(draws) @@ -84,13 +87,8 @@ as_draws_df(draws) # Plot posterior using bayesplot (ggplot2) mcmc_hist(fit_mcmc$draws("theta")) -# Call CmdStan's diagnose and stansummary utilities -fit_mcmc$cmdstan_diagnose() -fit_mcmc$cmdstan_summary() - # For models fit using MCMC, if you like working with RStan's stanfit objects # then you can create one with rstan::read_stan_csv() - # stanfit <- rstan::read_stan_csv(fit_mcmc$output_files()) @@ -98,13 +96,16 @@ fit_mcmc$cmdstan_summary() # and also demonstrate specifying data as a path to a file instead of a list my_data_file <- file.path(cmdstan_path(), "examples/bernoulli/bernoulli.data.json") fit_optim <- mod$optimize(data = my_data_file, seed = 123) - fit_optim$summary() +# Run 'optimize' again with 'jacobian=TRUE' and then draw from laplace approximation +# to the posterior +fit_optim <- mod$optimize(data = my_data_file, jacobian = TRUE) +fit_laplace <- mod$laplace(data = my_data_file, mode = fit_optim, draws = 2000) +fit_laplace$summary() # Run 'variational' method to approximate the posterior (default is meanfield ADVI) fit_vb <- mod$variational(data = stan_data, seed = 123) - fit_vb$summary() # Plot approximate posterior using bayesplot diff --git a/man/cmdstanr-package.Rd b/man/cmdstanr-package.Rd index 5a07aec2c..9f2feb180 100644 --- a/man/cmdstanr-package.Rd +++ b/man/cmdstanr-package.Rd @@ -99,6 +99,9 @@ fit_mcmc <- mod$sample( # Use 'posterior' package for summaries fit_mcmc$summary() +# Check sampling diagnostics +fit_mcmc$diagnostic_summary() + # Get posterior draws draws <- fit_mcmc$draws() print(draws) @@ -109,13 +112,8 @@ as_draws_df(draws) # Plot posterior using bayesplot (ggplot2) mcmc_hist(fit_mcmc$draws("theta")) -# Call CmdStan's diagnose and stansummary utilities -fit_mcmc$cmdstan_diagnose() -fit_mcmc$cmdstan_summary() - # For models fit using MCMC, if you like working with RStan's stanfit objects # then you can create one with rstan::read_stan_csv() - # stanfit <- rstan::read_stan_csv(fit_mcmc$output_files()) @@ -123,13 +121,16 @@ fit_mcmc$cmdstan_summary() # and also demonstrate specifying data as a path to a file instead of a list my_data_file <- file.path(cmdstan_path(), "examples/bernoulli/bernoulli.data.json") fit_optim <- mod$optimize(data = my_data_file, seed = 123) - fit_optim$summary() +# Run 'optimize' again with 'jacobian=TRUE' and then draw from laplace approximation +# to the posterior +fit_optim <- mod$optimize(data = my_data_file, jacobian = TRUE) +fit_laplace <- mod$laplace(data = my_data_file, mode = fit_optim, draws = 2000) +fit_laplace$summary() # Run 'variational' method to approximate the posterior (default is meanfield ADVI) fit_vb <- mod$variational(data = stan_data, seed = 123) - fit_vb$summary() # Plot approximate posterior using bayesplot diff --git a/man/cmdstanr_example.Rd b/man/cmdstanr_example.Rd index 2e89993f9..3474811b9 100644 --- a/man/cmdstanr_example.Rd +++ b/man/cmdstanr_example.Rd @@ -7,7 +7,7 @@ \usage{ cmdstanr_example( example = c("logistic", "schools", "schools_ncp"), - method = c("sample", "optimize", "variational", "diagnose"), + method = c("sample", "optimize", "laplace", "variational", "diagnose"), ..., quiet = TRUE, force_recompile = getOption("cmdstanr_force_recompile", default = FALSE) diff --git a/man/fit-method-lp.Rd b/man/fit-method-lp.Rd index 4bf3eedb2..acda71b7c 100644 --- a/man/fit-method-lp.Rd +++ b/man/fit-method-lp.Rd @@ -12,16 +12,17 @@ lp_approx() } \value{ A numeric vector with length equal to the number of (post-warmup) -draws for MCMC and variational inference, and length equal to \code{1} for -optimization. +draws or length equal to \code{1} for optimization. } \description{ The \verb{$lp()} method extracts \code{lp__}, the total log probability (\code{target}) accumulated in the model block of the Stan program. For variational inference the log density of the variational approximation to -the posterior is also available via the \verb{$lp_approx()} method. +the posterior is available via the \verb{$lp_approx()} method. For +Laplace approximation the unnormalized density of the approximation to +the posterior is available via the \verb{$lp_approx()} method. -See the \href{https://mc-stan.org/docs/2_23/reference-manual/sampling-statements-section.html}{Log Probability Increment vs. Sampling Statement} +See the \href{https://mc-stan.org/docs/reference-manual/sampling-statements.html}{Log Probability Increment vs. Sampling Statement} section of the Stan Reference Manual for details on when normalizing constants are dropped from log probability calculations. } @@ -32,10 +33,14 @@ This will in general be different than the unnormalized model log density evaluated at a posterior draw (which is on the constrained space). \code{lp__} is intended to diagnose sampling efficiency and evaluate approximations. -\code{lp_approx__} is the log density of the variational approximation to \code{lp__} -(also on the unconstrained space). It is exposed in the variational method -for performing the checks described in Yao et al. (2018) and implemented in -the \pkg{loo} package. +For variational inference \code{lp_approx__} is the log density of the variational +approximation to \code{lp__} (also on the unconstrained space). It is exposed in +the variational method for performing the checks described in Yao et al. +(2018) and implemented in the \pkg{loo} package. + +For Laplace approximation \code{lp_approx__} is the unnormalized density of the +Laplace approximation. It can be used to perform the same checks as in the +case of the variational method described in Yao et al. (2018). } \examples{ @@ -57,5 +62,5 @@ work?: Evaluating variational inference. \emph{Proceedings of the 35th International Conference on Machine Learning}, PMLR 80:5581–5590. } \seealso{ -\code{\link{CmdStanMCMC}}, \code{\link{CmdStanMLE}}, \code{\link{CmdStanVB}} +\code{\link{CmdStanMCMC}}, \code{\link{CmdStanMLE}}, \code{\link{CmdStanLaplace}}, \code{\link{CmdStanVB}} } diff --git a/man/fit-method-summary.Rd b/man/fit-method-summary.Rd index 89cc06c9a..787feb2ad 100644 --- a/man/fit-method-summary.Rd +++ b/man/fit-method-summary.Rd @@ -75,5 +75,5 @@ fit$print(c("alpha", "beta"), var2 = ~var(as.vector(.x))) } \seealso{ -\code{\link{CmdStanMCMC}}, \code{\link{CmdStanMLE}}, \code{\link{CmdStanVB}}, \code{\link{CmdStanGQ}} +\code{\link{CmdStanMCMC}}, \code{\link{CmdStanMLE}}, \code{\link{CmdStanLaplace}}, \code{\link{CmdStanVB}}, \code{\link{CmdStanGQ}} } diff --git a/man/fit-method-time.Rd b/man/fit-method-time.Rd index 4ed637d9f..38944fd77 100644 --- a/man/fit-method-time.Rd +++ b/man/fit-method-time.Rd @@ -20,18 +20,25 @@ and \code{"total"}. \description{ Report the run time in seconds. For MCMC additional information is provided about the run times of individual chains and the warmup and -sampling phases. +sampling phases. For Laplace approximation the time only include the time +for drawing the approximate sample and does not include the time +taken to run the \verb{$optimize()} method. } \examples{ \dontrun{ fit_mcmc <- cmdstanr_example("logistic", method = "sample") fit_mcmc$time() -fit_mle <- cmdstanr_example("logistic", method = "optimize") -fit_mle$time() - fit_vb <- cmdstanr_example("logistic", method = "variational") fit_vb$time() + +fit_mle <- cmdstanr_example("logistic", method = "optimize", jacobian = TRUE) +fit_mle$time() + +# use fit_mle to draw samples from laplace approximation +fit_laplace <- cmdstanr_example("logistic", method = "laplace", mode = fit_mle) +fit_laplace$time() # just time for drawing sample not for running optimize +fit_laplace$time()$total + fit_mle$time()$total # total time } } diff --git a/man/model-method-check_syntax.Rd b/man/model-method-check_syntax.Rd index 8f9623bf2..64e93c02d 100644 --- a/man/model-method-check_syntax.Rd +++ b/man/model-method-check_syntax.Rd @@ -83,6 +83,7 @@ Other CmdStanModel methods: \code{\link{model-method-expose_functions}}, \code{\link{model-method-format}}, \code{\link{model-method-generate-quantities}}, +\code{\link{model-method-laplace}}, \code{\link{model-method-optimize}}, \code{\link{model-method-sample_mpi}}, \code{\link{model-method-sample}}, diff --git a/man/model-method-compile.Rd b/man/model-method-compile.Rd index 34b8a25de..7e1805747 100644 --- a/man/model-method-compile.Rd +++ b/man/model-method-compile.Rd @@ -154,6 +154,7 @@ Other CmdStanModel methods: \code{\link{model-method-expose_functions}}, \code{\link{model-method-format}}, \code{\link{model-method-generate-quantities}}, +\code{\link{model-method-laplace}}, \code{\link{model-method-optimize}}, \code{\link{model-method-sample_mpi}}, \code{\link{model-method-sample}}, diff --git a/man/model-method-diagnose.Rd b/man/model-method-diagnose.Rd index 7f9cde7d4..15e7e21c4 100644 --- a/man/model-method-diagnose.Rd +++ b/man/model-method-diagnose.Rd @@ -125,6 +125,7 @@ Other CmdStanModel methods: \code{\link{model-method-expose_functions}}, \code{\link{model-method-format}}, \code{\link{model-method-generate-quantities}}, +\code{\link{model-method-laplace}}, \code{\link{model-method-optimize}}, \code{\link{model-method-sample_mpi}}, \code{\link{model-method-sample}}, diff --git a/man/model-method-expose_functions.Rd b/man/model-method-expose_functions.Rd index c01aa3c1a..066c7bad3 100644 --- a/man/model-method-expose_functions.Rd +++ b/man/model-method-expose_functions.Rd @@ -74,6 +74,7 @@ Other CmdStanModel methods: \code{\link{model-method-diagnose}}, \code{\link{model-method-format}}, \code{\link{model-method-generate-quantities}}, +\code{\link{model-method-laplace}}, \code{\link{model-method-optimize}}, \code{\link{model-method-sample_mpi}}, \code{\link{model-method-sample}}, diff --git a/man/model-method-format.Rd b/man/model-method-format.Rd index ad6b4e650..294ad8cce 100644 --- a/man/model-method-format.Rd +++ b/man/model-method-format.Rd @@ -103,6 +103,7 @@ Other CmdStanModel methods: \code{\link{model-method-diagnose}}, \code{\link{model-method-expose_functions}}, \code{\link{model-method-generate-quantities}}, +\code{\link{model-method-laplace}}, \code{\link{model-method-optimize}}, \code{\link{model-method-sample_mpi}}, \code{\link{model-method-sample}}, diff --git a/man/model-method-generate-quantities.Rd b/man/model-method-generate-quantities.Rd index ebe5d1f60..adf2fd146 100644 --- a/man/model-method-generate-quantities.Rd +++ b/man/model-method-generate-quantities.Rd @@ -174,6 +174,7 @@ Other CmdStanModel methods: \code{\link{model-method-diagnose}}, \code{\link{model-method-expose_functions}}, \code{\link{model-method-format}}, +\code{\link{model-method-laplace}}, \code{\link{model-method-optimize}}, \code{\link{model-method-sample_mpi}}, \code{\link{model-method-sample}}, diff --git a/man/model-method-laplace.Rd b/man/model-method-laplace.Rd new file mode 100644 index 000000000..21e0e1b68 --- /dev/null +++ b/man/model-method-laplace.Rd @@ -0,0 +1,207 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/model.R +\name{model-method-laplace} +\alias{model-method-laplace} +\alias{laplace} +\title{Run Stan's laplace algorithm} +\usage{ +laplace( + data = NULL, + seed = NULL, + refresh = NULL, + init = NULL, + save_latent_dynamics = FALSE, + output_dir = NULL, + output_basename = NULL, + sig_figs = NULL, + threads = NULL, + opencl_ids = NULL, + mode = NULL, + opt_args = NULL, + jacobian = TRUE, + draws = NULL +) +} +\arguments{ +\item{data}{(multiple options) The data to use for the variables specified in +the data block of the Stan program. One of the following: +\itemize{ +\item A named list of \R objects with the names corresponding to variables +declared in the data block of the Stan program. Internally this list is then +written to JSON for CmdStan using \code{\link[=write_stan_json]{write_stan_json()}}. See +\code{\link[=write_stan_json]{write_stan_json()}} for details on the conversions performed on \R objects +before they are passed to Stan. +\item A path to a data file compatible with CmdStan (JSON or \R dump). See the +appendices in the CmdStan guide for details on using these formats. +\item \code{NULL} or an empty list if the Stan program has no data block. +}} + +\item{seed}{(positive integer(s)) A seed for the (P)RNG to pass to CmdStan. +In the case of multi-chain sampling the single \code{seed} will automatically be +augmented by the the run (chain) ID so that each chain uses a different +seed. The exception is the transformed data block, which defaults to +using same seed for all chains so that the same data is generated for all +chains if RNG functions are used. The only time \code{seed} should be specified +as a vector (one element per chain) is if RNG functions are used in +transformed data and the goal is to generate \emph{different} data for each +chain.} + +\item{refresh}{(non-negative integer) The number of iterations between +printed screen updates. If \code{refresh = 0}, only error messages will be +printed.} + +\item{init}{(multiple options) The initialization method to use for the +variables declared in the parameters block of the Stan program. One of +the following: +\itemize{ +\item A real number \code{x>0}. This initializes \emph{all} parameters randomly between +\verb{[-x,x]} on the \emph{unconstrained} parameter space.; +\item The number \code{0}. This initializes \emph{all} parameters to \code{0}; +\item A character vector of paths (one per chain) to JSON or Rdump files +containing initial values for all or some parameters. See +\code{\link[=write_stan_json]{write_stan_json()}} to write \R objects to JSON files compatible with +CmdStan. +\item A list of lists containing initial values for all or some parameters. For +MCMC the list should contain a sublist for each chain. For optimization and +variational inference there should be just one sublist. The sublists should +have named elements corresponding to the parameters for which you are +specifying initial values. See \strong{Examples}. +\item A function that returns a single list with names corresponding to the +parameters for which you are specifying initial values. The function can +take no arguments or a single argument \code{chain_id}. For MCMC, if the function +has argument \code{chain_id} it will be supplied with the chain id (from 1 to +number of chains) when called to generate the initial values. See +\strong{Examples}. +}} + +\item{save_latent_dynamics}{Ignored for this method.} + +\item{output_dir}{(string) A path to a directory where CmdStan should write +its output CSV files. For interactive use this can typically be left at +\code{NULL} (temporary directory) since CmdStanR makes the CmdStan output +(posterior draws and diagnostics) available in \R via methods of the fitted +model objects. The behavior of \code{output_dir} is as follows: +\itemize{ +\item If \code{NULL} (the default), then the CSV files are written to a temporary +directory and only saved permanently if the user calls one of the \verb{$save_*} +methods of the fitted model object (e.g., +\code{\link[=fit-method-save_output_files]{$save_output_files()}}). These temporary +files are removed when the fitted model object is +\link[base:gc]{garbage collected} (manually or automatically). +\item If a path, then the files are created in \code{output_dir} with names +corresponding to the defaults used by \verb{$save_output_files()}. +}} + +\item{output_basename}{(string) A string to use as a prefix for the names of +the output CSV files of CmdStan. If \code{NULL} (the default), the basename of +the output CSV files will be comprised from the model name, timestamp, and +5 random characters.} + +\item{sig_figs}{(positive integer) The number of significant figures used +when storing the output values. By default, CmdStan represent the output +values with 6 significant figures. The upper limit for \code{sig_figs} is 18. +Increasing this value will result in larger output CSV files and thus an +increased usage of disk space.} + +\item{threads}{(positive integer) If the model was +\link[=model-method-compile]{compiled} with threading support, the number of +threads to use in parallelized sections (e.g., when +using the Stan functions \code{reduce_sum()} or \code{map_rect()}).} + +\item{opencl_ids}{(integer vector of length 2) The platform and +device IDs of the OpenCL device to use for fitting. The model must +be compiled with \code{cpp_options = list(stan_opencl = TRUE)} for this +argument to have an effect.} + +\item{mode}{(multiple options) The mode to center the approximation at. One +of the following: +\itemize{ +\item A \code{\link{CmdStanMLE}} object from a previous run of \code{\link[=model-method-optimize]{$optimize()}}. +\item The path to a CmdStan CSV file from running optimization. +\item \code{NULL}, in which case \link[=model-method-optimize]{$optimize()} will be run +with \code{jacobian=jacobian} (see the \code{jacobian} argument below). +} + +In all cases the total time reported by \code{\link[=fit-method-time]{$time()}} will be +the time of the Laplace sampling step only and does not include the time +taken to run the \verb{$optimize()} method.} + +\item{opt_args}{(named list) A named list of optional arguments to pass to +\link[=model-method-optimize]{$optimize()} if \code{mode=NULL}.} + +\item{jacobian}{(logical) Whether or not to enable the Jacobian adjustment +for constrained parameters. The default is \code{TRUE}. See the +\href{https://mc-stan.org/docs/cmdstan-guide/laplace-sampling.html}{Laplace Sampling} +section of the CmdStan User's Guide for more details. If \code{mode} is not +\code{NULL} then the value of \code{jacobian} must match the value used when +optimization was originally run. If \code{mode} is \code{NULL} then the value of +\code{jacobian} specified here is used when running optimization.} + +\item{draws}{(positive integer) The number of draws to take.} +} +\value{ +A \code{\link{CmdStanLaplace}} object. +} +\description{ +The \verb{$laplace()} method of a \code{\link{CmdStanModel}} object produces a +sample from a normal approximation centered at the mode of a distribution +in the unconstrained space. If the mode is a maximum a posteriori (MAP) +estimate, the samples provide an estimate of the mean and standard +deviation of the posterior distribution. If the mode is a maximum +likelihood estimate (MLE), the sample provides an estimate of the standard +error of the likelihood. Whether the mode is the MAP or MLE depends on +the value of the \code{jacobian} argument when running optimization. See the +\href{https://mc-stan.org/docs/cmdstan-guide/laplace-sampling.html}{Laplace Sampling} +section of the CmdStan User's Guide for more details. + +Any argument left as \code{NULL} will default to the default value used by the +installed version of CmdStan. See the +\href{https://mc-stan.org/docs/cmdstan-guide/}{CmdStan User’s Guide} +for more details on the default arguments. +} +\examples{ +\dontrun{ +file <- file.path(cmdstan_path(), "examples/bernoulli/bernoulli.stan") +mod <- cmdstan_model(file) +mod$print() + +stan_data <- list(N = 10, y = c(0,1,0,0,0,0,0,0,0,1)) +fit_mode <- mod$optimize(data = stan_data, jacobian = TRUE) +fit_laplace <- mod$laplace(data = stan_data, mode = fit_mode) +fit_laplace$summary() + +# if mode isn't specified optimize is run internally first +fit_laplace <- mod$laplace(data = stan_data) +fit_laplace$summary() + +# plot approximate posterior +bayesplot::mcmc_hist(fit_laplace$draws("theta")) +} + + +} +\seealso{ +The CmdStanR website +(\href{https://mc-stan.org/cmdstanr/}{mc-stan.org/cmdstanr}) for online +documentation and tutorials. + +The Stan and CmdStan documentation: +\itemize{ +\item Stan documentation: \href{https://mc-stan.org/users/documentation/}{mc-stan.org/users/documentation} +\item CmdStan User’s Guide: \href{https://mc-stan.org/docs/cmdstan-guide/}{mc-stan.org/docs/cmdstan-guide} +} + +Other CmdStanModel methods: +\code{\link{model-method-check_syntax}}, +\code{\link{model-method-compile}}, +\code{\link{model-method-diagnose}}, +\code{\link{model-method-expose_functions}}, +\code{\link{model-method-format}}, +\code{\link{model-method-generate-quantities}}, +\code{\link{model-method-optimize}}, +\code{\link{model-method-sample_mpi}}, +\code{\link{model-method-sample}}, +\code{\link{model-method-variables}}, +\code{\link{model-method-variational}} +} +\concept{CmdStanModel methods} diff --git a/man/model-method-optimize.Rd b/man/model-method-optimize.Rd index e2444673b..06535f141 100644 --- a/man/model-method-optimize.Rd +++ b/man/model-method-optimize.Rd @@ -134,11 +134,13 @@ the CmdStan User's Guide. The default values can also be obtained by running \code{cmdstanr_example(method="optimize")$metadata()}.} \item{jacobian}{(logical) Whether or not to use the Jacobian adjustment for -constrained variables. By default this is \code{FALSE}, meaning optimization +constrained variables. For historical reasons, the default is \code{FALSE}, meaning optimization yields the (regularized) maximum likelihood estimate. Setting it to \code{TRUE} yields the maximum a posteriori estimate. See the \href{https://mc-stan.org/docs/cmdstan-guide/maximum-likelihood-estimation.html}{Maximum Likelihood Estimation} -section of the CmdStan User's Guide for more details.} +section of the CmdStan User's Guide for more details. +For use later with \code{\link[=model-method-laplace]{$laplace()}} the \code{jacobian} +argument should typically be set to \code{TRUE}.} \item{init_alpha}{(positive real) The initial step size parameter.} @@ -207,6 +209,9 @@ fit_mcmc <- mod$sample( # Use 'posterior' package for summaries fit_mcmc$summary() +# Check sampling diagnostics +fit_mcmc$diagnostic_summary() + # Get posterior draws draws <- fit_mcmc$draws() print(draws) @@ -217,13 +222,8 @@ as_draws_df(draws) # Plot posterior using bayesplot (ggplot2) mcmc_hist(fit_mcmc$draws("theta")) -# Call CmdStan's diagnose and stansummary utilities -fit_mcmc$cmdstan_diagnose() -fit_mcmc$cmdstan_summary() - # For models fit using MCMC, if you like working with RStan's stanfit objects # then you can create one with rstan::read_stan_csv() - # stanfit <- rstan::read_stan_csv(fit_mcmc$output_files()) @@ -231,13 +231,16 @@ fit_mcmc$cmdstan_summary() # and also demonstrate specifying data as a path to a file instead of a list my_data_file <- file.path(cmdstan_path(), "examples/bernoulli/bernoulli.data.json") fit_optim <- mod$optimize(data = my_data_file, seed = 123) - fit_optim$summary() +# Run 'optimize' again with 'jacobian=TRUE' and then draw from laplace approximation +# to the posterior +fit_optim <- mod$optimize(data = my_data_file, jacobian = TRUE) +fit_laplace <- mod$laplace(data = my_data_file, mode = fit_optim, draws = 2000) +fit_laplace$summary() # Run 'variational' method to approximate the posterior (default is meanfield ADVI) fit_vb <- mod$variational(data = stan_data, seed = 123) - fit_vb$summary() # Plot approximate posterior using bayesplot @@ -304,6 +307,7 @@ Other CmdStanModel methods: \code{\link{model-method-expose_functions}}, \code{\link{model-method-format}}, \code{\link{model-method-generate-quantities}}, +\code{\link{model-method-laplace}}, \code{\link{model-method-sample_mpi}}, \code{\link{model-method-sample}}, \code{\link{model-method-variables}}, diff --git a/man/model-method-sample.Rd b/man/model-method-sample.Rd index e0955abb4..1c415f6a6 100644 --- a/man/model-method-sample.Rd +++ b/man/model-method-sample.Rd @@ -318,6 +318,9 @@ fit_mcmc <- mod$sample( # Use 'posterior' package for summaries fit_mcmc$summary() +# Check sampling diagnostics +fit_mcmc$diagnostic_summary() + # Get posterior draws draws <- fit_mcmc$draws() print(draws) @@ -328,13 +331,8 @@ as_draws_df(draws) # Plot posterior using bayesplot (ggplot2) mcmc_hist(fit_mcmc$draws("theta")) -# Call CmdStan's diagnose and stansummary utilities -fit_mcmc$cmdstan_diagnose() -fit_mcmc$cmdstan_summary() - # For models fit using MCMC, if you like working with RStan's stanfit objects # then you can create one with rstan::read_stan_csv() - # stanfit <- rstan::read_stan_csv(fit_mcmc$output_files()) @@ -342,13 +340,16 @@ fit_mcmc$cmdstan_summary() # and also demonstrate specifying data as a path to a file instead of a list my_data_file <- file.path(cmdstan_path(), "examples/bernoulli/bernoulli.data.json") fit_optim <- mod$optimize(data = my_data_file, seed = 123) - fit_optim$summary() +# Run 'optimize' again with 'jacobian=TRUE' and then draw from laplace approximation +# to the posterior +fit_optim <- mod$optimize(data = my_data_file, jacobian = TRUE) +fit_laplace <- mod$laplace(data = my_data_file, mode = fit_optim, draws = 2000) +fit_laplace$summary() # Run 'variational' method to approximate the posterior (default is meanfield ADVI) fit_vb <- mod$variational(data = stan_data, seed = 123) - fit_vb$summary() # Plot approximate posterior using bayesplot @@ -415,6 +416,7 @@ Other CmdStanModel methods: \code{\link{model-method-expose_functions}}, \code{\link{model-method-format}}, \code{\link{model-method-generate-quantities}}, +\code{\link{model-method-laplace}}, \code{\link{model-method-optimize}}, \code{\link{model-method-sample_mpi}}, \code{\link{model-method-variables}}, diff --git a/man/model-method-sample_mpi.Rd b/man/model-method-sample_mpi.Rd index 89981273f..8c17a4dee 100644 --- a/man/model-method-sample_mpi.Rd +++ b/man/model-method-sample_mpi.Rd @@ -314,6 +314,7 @@ Other CmdStanModel methods: \code{\link{model-method-expose_functions}}, \code{\link{model-method-format}}, \code{\link{model-method-generate-quantities}}, +\code{\link{model-method-laplace}}, \code{\link{model-method-optimize}}, \code{\link{model-method-sample}}, \code{\link{model-method-variables}}, diff --git a/man/model-method-variables.Rd b/man/model-method-variables.Rd index aa609ddda..b5b85f1c2 100644 --- a/man/model-method-variables.Rd +++ b/man/model-method-variables.Rd @@ -43,6 +43,7 @@ Other CmdStanModel methods: \code{\link{model-method-expose_functions}}, \code{\link{model-method-format}}, \code{\link{model-method-generate-quantities}}, +\code{\link{model-method-laplace}}, \code{\link{model-method-optimize}}, \code{\link{model-method-sample_mpi}}, \code{\link{model-method-sample}}, diff --git a/man/model-method-variational.Rd b/man/model-method-variational.Rd index 21c92fe24..b04f8b7c1 100644 --- a/man/model-method-variational.Rd +++ b/man/model-method-variational.Rd @@ -25,7 +25,8 @@ variational( adapt_iter = NULL, tol_rel_obj = NULL, eval_elbo = NULL, - output_samples = NULL + output_samples = NULL, + draws = NULL ) } \arguments{ @@ -151,7 +152,10 @@ of the objective.} \item{eval_elbo}{(positive integer) Evaluate ELBO every Nth iteration.} -\item{output_samples}{(positive integer) Number of approximate posterior +\item{output_samples}{(positive integer) Use \code{draws} argument instead. +\code{output_samples} will be deprecated in the future.} + +\item{draws}{(positive integer) Number of approximate posterior samples to draw and save.} } \value{ @@ -210,6 +214,9 @@ fit_mcmc <- mod$sample( # Use 'posterior' package for summaries fit_mcmc$summary() +# Check sampling diagnostics +fit_mcmc$diagnostic_summary() + # Get posterior draws draws <- fit_mcmc$draws() print(draws) @@ -220,13 +227,8 @@ as_draws_df(draws) # Plot posterior using bayesplot (ggplot2) mcmc_hist(fit_mcmc$draws("theta")) -# Call CmdStan's diagnose and stansummary utilities -fit_mcmc$cmdstan_diagnose() -fit_mcmc$cmdstan_summary() - # For models fit using MCMC, if you like working with RStan's stanfit objects # then you can create one with rstan::read_stan_csv() - # stanfit <- rstan::read_stan_csv(fit_mcmc$output_files()) @@ -234,13 +236,16 @@ fit_mcmc$cmdstan_summary() # and also demonstrate specifying data as a path to a file instead of a list my_data_file <- file.path(cmdstan_path(), "examples/bernoulli/bernoulli.data.json") fit_optim <- mod$optimize(data = my_data_file, seed = 123) - fit_optim$summary() +# Run 'optimize' again with 'jacobian=TRUE' and then draw from laplace approximation +# to the posterior +fit_optim <- mod$optimize(data = my_data_file, jacobian = TRUE) +fit_laplace <- mod$laplace(data = my_data_file, mode = fit_optim, draws = 2000) +fit_laplace$summary() # Run 'variational' method to approximate the posterior (default is meanfield ADVI) fit_vb <- mod$variational(data = stan_data, seed = 123) - fit_vb$summary() # Plot approximate posterior using bayesplot @@ -307,6 +312,7 @@ Other CmdStanModel methods: \code{\link{model-method-expose_functions}}, \code{\link{model-method-format}}, \code{\link{model-method-generate-quantities}}, +\code{\link{model-method-laplace}}, \code{\link{model-method-optimize}}, \code{\link{model-method-sample_mpi}}, \code{\link{model-method-sample}}, diff --git a/man/read_cmdstan_csv.Rd b/man/read_cmdstan_csv.Rd index a986843e9..86c4832df 100644 --- a/man/read_cmdstan_csv.Rd +++ b/man/read_cmdstan_csv.Rd @@ -49,7 +49,7 @@ diagnostic checks be performed after reading in the files? The default is and treedepth.} } \value{ -\code{as_cmdstan_fit()} returns a \link{CmdStanMCMC}, \link{CmdStanMLE}, or +\code{as_cmdstan_fit()} returns a \link{CmdStanMCMC}, \link{CmdStanMLE}, \link{CmdStanLaplace} or \link{CmdStanVB} object. Some methods typically defined for those objects will not work (e.g. \code{save_data_file()}) but the important methods like \verb{$summary()}, \verb{$draws()}, \verb{$sampler_diagnostics()} and others will work fine. @@ -92,7 +92,8 @@ following components: \item \code{point_estimates}: Point estimates for the model parameters. } -For \link[=model-method-variational]{variational inference} the returned list also +For \link[=model-method-laplace]{laplace} and +\link[=model-method-variational]{variational inference} the returned list also includes the following components: \itemize{ \item \code{draws}: A \code{\link[posterior:draws_matrix]{draws_matrix}} (or different format diff --git a/tests/testthat/helper-custom-expectations.R b/tests/testthat/helper-custom-expectations.R index e025670a0..bbf3d0d8e 100644 --- a/tests/testthat/helper-custom-expectations.R +++ b/tests/testthat/helper-custom-expectations.R @@ -64,6 +64,13 @@ expect_optim_output <- function(object) { ) } +expect_laplace_output <- function(object) { + expect_output( + object, + regexp = "Generating draws" + ) +} + expect_vb_output <- function(object) { expect_output( object, diff --git a/tests/testthat/helper-models.R b/tests/testthat/helper-models.R index bed88c0a1..b0773e8b0 100644 --- a/tests/testthat/helper-models.R +++ b/tests/testthat/helper-models.R @@ -18,11 +18,19 @@ testing_model <- function(name) { cmdstan_model(stan_file = testing_stan_file(name)) } -testing_fit <- function(name, method = c("sample", "optimize", "variational", "generate_quantities"), seed = 123, ...) { - method <- match.arg(method) - mod <- testing_model(name) - utils::capture.output( - fit <- mod[[method]](data = testing_data(name), seed = seed, ...) - ) - fit -} +testing_fit <- + function(name, + method = c("sample", + "optimize", + "laplace", + "variational", + "generate_quantities"), + seed = 123, + ...) { + method <- match.arg(method) + mod <- testing_model(name) + utils::capture.output( + fit <- mod[[method]](data = testing_data(name), seed = seed, ...) + ) + fit + } diff --git a/tests/testthat/resources/csv/bernoulli-1-laplace.csv b/tests/testthat/resources/csv/bernoulli-1-laplace.csv new file mode 100644 index 000000000..21a2c93f3 --- /dev/null +++ b/tests/testthat/resources/csv/bernoulli-1-laplace.csv @@ -0,0 +1,1026 @@ +# stan_version_major = 2 +# stan_version_minor = 32 +# stan_version_patch = 2 +# model = bernoulli_model +# start_datetime = 2023-07-31 16:31:26 UTC +# method = laplace +# laplace +# mode = /var/folders/s0/zfzm55px2nd2v__zlw5xfj2h0000gn/T/RtmpxCKExE/bernoulli-202307311031-1-6e33fd.csv +# jacobian = 1 (Default) +# draws = 1000 (Default) +# id = 1 (Default) +# data +# file = /var/folders/s0/zfzm55px2nd2v__zlw5xfj2h0000gn/T/RtmpxCKExE/standata-202227469bc5.json +# init = 2 (Default) +# random +# seed = 123 +# output +# file = /var/folders/s0/zfzm55px2nd2v__zlw5xfj2h0000gn/T/RtmpxCKExE/bernoulli-202307311031-1-979a38.csv +# diagnostic_file = (Default) +# refresh = 100 (Default) +# sig_figs = -1 (Default) +# profile_file = /var/folders/s0/zfzm55px2nd2v__zlw5xfj2h0000gn/T/RtmpxCKExE/bernoulli-profile-202307311031-1-1828a7.csv +# num_threads = 1 (Default) +# stanc_version = stanc3 v2.32.2 +# stancflags = --name=bernoulli_model +log_p__,log_q__,theta +-7.16669,-0.383528,0.374086 +-6.7669,-0.0193024,0.226252 +-7.87754,-1.38747,0.098932 +-8.05328,-1.63333,0.0908309 +-7.66165,-1.09546,0.110525 +-6.7498,-0.00179038,0.242596 +-6.78061,-0.03171,0.282779 +-6.95711,-0.226279,0.175507 +-8.00475,-1.09855,0.472419 +-7.05561,-0.339152,0.161424 +-7.17748,-0.393042,0.375772 +-7.58062,-0.741659,0.428821 +-8.09523,-1.69306,0.0890438 +-6.75524,-0.00711871,0.265211 +-7.3554,-0.54838,0.401208 +-6.89127,-0.135625,0.320517 +-6.86483,-0.111137,0.313394 +-7.53461,-0.928912,0.11844 +-7.25703,-0.579411,0.139883 +-6.889,-0.150268,0.187844 +-7.55452,-0.719424,0.42582 +-6.83285,-0.0890597,0.201014 +-7.90476,-1.42508,0.0976071 +-8.65984,-1.63837,0.527018 +-7.43018,-0.612918,0.410841 +-6.93293,-0.173899,0.330606 +-6.92726,-0.192753,0.180564 +-7.00197,-0.236619,0.345248 +-7.53642,-0.703976,0.42371 +-7.14386,-0.442941,0.151089 +-6.7488,-0.000771829,0.254944 +-6.80456,-0.0545645,0.293514 +-6.76658,-0.0181793,0.274584 +-6.75519,-0.00707134,0.265159 +-6.81197,-0.0666928,0.207167 +-7.0906,-0.380012,0.157122 +-6.75871,-0.0108709,0.232026 +-7.41545,-0.776401,0.126823 +-7.19904,-0.509023,0.145386 +-6.76646,-0.0188509,0.226522 +-6.93036,-0.171549,0.330016 +-7.18281,-0.39774,0.376599 +-6.74874,-0.000718872,0.254771 +-6.7876,-0.0408973,0.215976 +-6.76237,-0.0140874,0.271563 +-7.6106,-0.767164,0.432215 +-6.7618,-0.0135316,0.271123 +-7.47944,-0.857851,0.122192 +-8.97548,-1.8954,0.549683 +-7.50937,-0.896315,0.12013 +-7.37605,-0.566243,0.403921 +-6.96161,-0.231366,0.174783 +-6.84033,-0.0971273,0.199019 +-7.53953,-0.706635,0.424075 +-6.9767,-0.248462,0.172422 +-7.01971,-0.297637,0.166177 +-6.74873,-0.000701807,0.254713 +-7.41007,-0.595605,0.408302 +-7.23376,-0.442483,0.384272 +-8.08504,-1.67851,0.0894731 +-6.85412,-0.101159,0.310295 +-6.75198,-0.00391608,0.261226 +-6.75569,-0.00756515,0.26569 +-7.28374,-0.486142,0.391444 +-7.5764,-0.983259,0.115731 +-8.7977,-1.75082,0.537157 +-7.48434,-0.659426,0.417511 +-6.78289,-0.0359477,0.218 +-6.90537,-0.148623,0.324068 +-7.53121,-0.924511,0.118665 +-7.46921,-0.844758,0.122911 +-7.00738,-0.241502,0.346313 +-7.12055,-0.415296,0.153659 +-6.77965,-0.0307829,0.282278 +-6.9983,-0.233311,0.344521 +-6.91179,-0.175507,0.183385 +-6.82368,-0.0726385,0.300587 +-6.76327,-0.0149604,0.27224 +-7.91397,-1.43783,0.0971654 +-6.83795,-0.0945629,0.199642 +-7.23772,-0.555868,0.141665 +-8.13919,-1.75611,0.087225 +-6.96739,-0.205311,0.338181 +-7.80248,-0.929304,0.452708 +-6.77448,-0.0271721,0.222005 +-7.3228,-0.520121,0.396838 +-6.7862,-0.0394239,0.216564 +-9.91191,-2.652,0.607485 +-7.47478,-0.65123,0.416351 +-6.75179,-0.00373055,0.260953 +-6.89605,-0.158054,0.186422 +-6.75191,-0.00385486,0.261136 +-6.74824,-0.000218131,0.247399 +-7.25091,-0.571938,0.140442 +-6.98238,-0.218908,0.341304 +-7.87302,-0.988505,0.459779 +-7.06168,-0.346208,0.160657 +-6.76217,-0.0144238,0.229376 +-6.91403,-0.177995,0.182967 +-7.48189,-0.860998,0.12202 +-6.82442,-0.0800045,0.203383 +-7.05516,-0.33863,0.161482 +-6.75489,-0.00678426,0.264842 +-6.75219,-0.00421688,0.238698 +-6.75666,-0.00877391,0.23381 +-8.18254,-1.81872,0.0854831 +-8.1402,-1.21109,0.484742 +-6.76222,-0.01448,0.229337 +-7.61297,-0.769174,0.43248 +-7.36186,-0.553973,0.402062 +-6.75183,-0.00376951,0.261011 +-8.54777,-1.54672,0.518476 +-6.80693,-0.0568099,0.294445 +-6.7572,-0.00931939,0.233326 +-6.74821,-0.000185088,0.252414 +-6.74821,-0.000190997,0.252452 +-6.78203,-0.0330753,0.283505 +-7.08116,-0.368955,0.158253 +-6.84219,-0.0991379,0.198537 +-6.81748,-0.0667908,0.298396 +-7.38856,-0.577054,0.405545 +-6.75534,-0.00742506,0.23508 +-6.74885,-0.000826753,0.255118 +-7.64143,-1.06869,0.111724 +-6.9617,-0.200141,0.336971 +-8.30331,-1.34589,0.498792 +-7.6539,-0.803906,0.437017 +-6.75217,-0.0041048,0.261497 +-6.92635,-0.191734,0.180726 +-7.24715,-0.567352,0.140789 +-7.25837,-0.46401,0.387844 +-6.75015,-0.00210888,0.258207 +-7.0363,-0.267526,0.35184 +-7.12074,-0.342882,0.366669 +-6.78828,-0.041607,0.215697 +-6.91469,-0.157187,0.326334 +-6.78416,-0.0372774,0.217442 +-7.55967,-0.961455,0.116802 +-8.00199,-1.56083,0.0930892 +-6.86527,-0.111548,0.313519 +-6.87028,-0.129712,0.19183 +-12.2616,-4.54887,0.713454 +-6.7712,-0.0226505,0.277537 +-6.84358,-0.100646,0.19818 +-7.99036,-1.08656,0.471072 +-6.74846,-0.000433075,0.253698 +-8.11825,-1.19289,0.482787 +-7.09996,-0.324419,0.363176 +-6.75428,-0.00633409,0.236197 +-6.78764,-0.0409331,0.215962 +-6.76511,-0.0167484,0.273568 +-6.91617,-0.180374,0.182571 +-6.76755,-0.0191187,0.275231 +-7.73308,-1.19083,0.106457 +-7.68887,-0.833514,0.440816 +-7.48711,-0.867691,0.121657 +-6.75562,-0.0074987,0.265619 +-7.11886,-0.413297,0.15385 +-7.55829,-0.722635,0.426255 +-8.47846,-2.25747,0.074802 +-6.8271,-0.0758555,0.301758 +-7.03017,-0.309693,0.164753 +-6.75711,-0.00923077,0.233404 +-6.74808,-5.515e-05,0.248691 +-6.82563,-0.0812998,0.203035 +-6.75542,-0.00750291,0.235003 +-7.63703,-1.06286,0.111988 +-6.77387,-0.0252225,0.279111 +-6.77252,-0.0239186,0.278323 +-6.76132,-0.0130638,0.270745 +-7.01839,-0.296112,0.16636 +-7.68159,-1.12195,0.109365 +-6.92574,-0.191056,0.180834 +-7.2046,-0.51573,0.144838 +-6.75415,-0.00620822,0.236332 +-7.31516,-0.513483,0.395796 +-7.43835,-0.619952,0.411863 +-8.03585,-1.60862,0.0915893 +-6.77167,-0.0231034,0.27782 +-6.84357,-0.0913115,0.3071 +-6.99478,-0.269062,0.169714 +-7.27835,-0.605526,0.137969 +-7.34283,-0.537495,0.399537 +-7.05595,-0.285147,0.355453 +-7.32879,-0.525318,0.397649 +-6.74828,-0.000260436,0.252865 +-7.1079,-0.400358,0.155099 +-7.96087,-1.50315,0.0949605 +-6.76918,-0.0206924,0.276281 +-6.75197,-0.00398559,0.239008 +-7.15165,-0.452219,0.150253 +-7.3641,-0.555911,0.402357 +-6.85909,-0.117489,0.194384 +-6.79146,-0.0421029,0.287993 +-6.74831,-0.0002907,0.253027 +-6.95704,-0.226205,0.175518 +-6.79312,-0.0436786,0.288729 +-7.31377,-0.649196,0.134907 +-6.95219,-0.191481,0.334913 +-6.78673,-0.0399761,0.216342 +-7.24819,-0.45511,0.386376 +-6.81503,-0.0699538,0.206203 +-8.31493,-2.01256,0.0804596 +-6.97815,-0.250108,0.1722 +-6.77982,-0.0327382,0.219395 +-6.84421,-0.0919124,0.307299 +-6.76467,-0.0163241,0.273259 +-6.92183,-0.18668,0.181538 +-7.57678,-0.983753,0.115707 +-6.88544,-0.130242,0.319003 +-7.1944,-0.407943,0.37838 +-7.71617,-1.16814,0.107398 +-7.23985,-0.558464,0.141465 +-6.8009,-0.0510907,0.292039 +-7.46185,-0.835361,0.123433 +-6.74897,-0.000954477,0.244579 +-6.82336,-0.078861,0.203693 +-7.79933,-0.926656,0.452387 +-8.42745,-2.18043,0.0765069 +-8.12327,-1.19706,0.483236 +-6.87952,-0.139837,0.189822 +-6.82765,-0.0834644,0.202461 +-6.97735,-0.249204,0.172322 +-8.28196,-1.9639,0.0816714 +-6.76284,-0.0151122,0.228904 +-6.74859,-0.000570701,0.254248 +-8.74831,-2.6747,0.0665755 +-7.66847,-1.10451,0.110126 +-6.77569,-0.0284308,0.221391 +-6.78007,-0.0330037,0.219277 +-6.79958,-0.0498338,0.291494 +-6.83185,-0.0879861,0.201288 +-8.0416,-1.12922,0.475836 +-10.7652,-3.33883,0.651159 +-6.98958,-0.22542,0.342769 +-6.80289,-0.0570379,0.210191 +-7.3188,-0.516648,0.396293 +-7.03166,-0.263363,0.350972 +-6.81364,-0.0684733,0.206637 +-7.49289,-0.666754,0.418543 +-6.75029,-0.00228844,0.24164 +-7.88282,-1.39475,0.0986728 +-6.74847,-0.000442915,0.25374 +-8.01181,-1.57466,0.0926504 +-6.7746,-0.0272914,0.221946 +-6.869,-0.128304,0.192116 +-6.79697,-0.050768,0.212313 +-6.96632,-0.236695,0.174035 +-6.93595,-0.176664,0.331296 +-6.84661,-0.0941474,0.308035 +-11.0747,-3.58812,0.665364 +-6.74841,-0.000388626,0.253502 +-7.05466,-0.338044,0.161546 +-7.05993,-0.344176,0.160877 +-6.78572,-0.0389231,0.216766 +-7.64173,-1.06908,0.111706 +-6.76597,-0.0183383,0.226833 +-6.79367,-0.0442042,0.288972 +-7.26819,-0.472585,0.389248 +-6.92094,-0.185697,0.181698 +-6.82656,-0.0753505,0.301576 +-6.78446,-0.0354076,0.284714 +-7.72221,-0.861681,0.444374 +-6.98611,-0.259172,0.170997 +-6.84453,-0.092207,0.307397 +-7.24583,-0.565741,0.140911 +-7.07516,-0.361941,0.158983 +-7.02214,-0.300428,0.165844 +-6.75189,-0.00390548,0.239117 +-7.12847,-0.424663,0.152775 +-6.76685,-0.0184327,0.27476 +-6.8027,-0.056835,0.210258 +-8.40224,-2.14257,0.0773691 +-6.91294,-0.176785,0.18317 +-6.88373,-0.128658,0.318553 +-6.86718,-0.113317,0.314054 +-6.78805,-0.0388439,0.286429 +-6.78128,-0.0342634,0.218723 +-6.85365,-0.11157,0.195679 +-7.04352,-0.325127,0.162983 +-7.13017,-0.351244,0.368224 +-6.80259,-0.0567199,0.210295 +-6.76193,-0.0141814,0.229545 +-6.79096,-0.0444235,0.214616 +-7.09002,-0.315572,0.361473 +-9.11189,-3.26238,0.0572425 +-7.00966,-0.286091,0.167578 +-7.20751,-0.419458,0.380367 +-7.01583,-0.249121,0.347957 +-7.29513,-0.49606,0.393035 +-7.65645,-0.806066,0.437296 +-7.77415,-1.24623,0.104228 +-6.75027,-0.00223528,0.258452 +-6.82336,-0.0723346,0.300475 +-7.24853,-0.569037,0.140661 +-6.86544,-0.124417,0.192918 +-6.81267,-0.0622519,0.296635 +-6.75252,-0.00445351,0.261983 +-8.7934,-1.74731,0.536846 +-7.34711,-0.6906,0.132148 +-7.46379,-0.837836,0.123295 +-6.84889,-0.106396,0.196844 +-7.44875,-0.62889,0.413155 +-6.86276,-0.109211,0.312805 +-6.95375,-0.192899,0.335252 +-7.15541,-0.456693,0.149854 +-6.83251,-0.0886938,0.201107 +-6.7712,-0.0237637,0.22375 +-6.88352,-0.128462,0.318497 +-10.3085,-5.40172,0.0359213 +-6.75317,-0.00521177,0.237457 +-6.80984,-0.0644161,0.207857 +-7.10956,-0.332956,0.364801 +-6.97301,-0.210411,0.339363 +-6.81818,-0.0733196,0.205234 +-6.94473,-0.212334,0.177548 +-6.82422,-0.0731479,0.300774 +-6.90337,-0.166161,0.184986 +-6.89561,-0.139634,0.321628 +-6.91719,-0.159482,0.326932 +-6.78382,-0.0347927,0.284399 +-7.1617,-0.379126,0.373299 +-6.76526,-0.0176075,0.227284 +-7.09598,-0.386331,0.156485 +-6.75581,-0.00789823,0.234622 +-6.75637,-0.00823304,0.266381 +-7.02424,-0.30285,0.165557 +-6.88537,-0.146268,0.188592 +-7.1169,-0.410973,0.154072 +-6.81678,-0.0718229,0.205662 +-6.75847,-0.0106176,0.232231 +-6.75381,-0.0057197,0.263607 +-7.50378,-0.676073,0.419848 +-6.88467,-0.129529,0.318801 +-6.8444,-0.092086,0.307357 +-6.93518,-0.201608,0.179177 +-7.13746,-0.435327,0.151786 +-8.27136,-1.31954,0.496102 +-6.74841,-0.000382312,0.253473 +-6.74835,-0.000327133,0.253212 +-7.26143,-0.466683,0.388283 +-6.82369,-0.0792217,0.203595 +-7.44545,-0.626053,0.412746 +-7.43404,-0.799965,0.125443 +-6.76387,-0.0155488,0.272685 +-6.75471,-0.00677949,0.23573 +-8.42441,-1.44554,0.508734 +-6.85671,-0.114895,0.194947 +-8.76343,-1.72289,0.534673 +-7.66319,-1.0975,0.110435 +-6.85067,-0.0979458,0.309268 +-7.25012,-0.456796,0.386655 +-6.83079,-0.0868396,0.201582 +-6.81698,-0.072034,0.205601 +-6.76614,-0.0177463,0.274281 +-6.91965,-0.161738,0.327516 +-6.97865,-0.250677,0.172124 +-7.86062,-0.978113,0.458552 +-6.82398,-0.0795313,0.203511 +-6.79884,-0.0527412,0.21163 +-7.08127,-0.307772,0.359954 +-9.83959,-2.59376,0.603435 +-7.87129,-1.37885,0.0992402 +-7.0223,-0.300617,0.165822 +-7.67849,-0.824733,0.439696 +-6.74834,-0.000315332,0.253153 +-6.87183,-0.131408,0.191487 +-7.5155,-0.904215,0.119716 +-8.70854,-2.61218,0.0677108 +-6.74836,-0.00034119,0.24675 +-7.13433,-0.35493,0.368904 +-6.7733,-0.0259426,0.22262 +-6.95155,-0.220015,0.176413 +-6.83074,-0.0867954,0.201593 +-7.06453,-0.292816,0.356995 +-7.31922,-0.655943,0.134448 +-6.74961,-0.00159402,0.24301 +-7.72667,-1.18222,0.106812 +-7.76417,-1.23272,0.104763 +-6.7595,-0.0112856,0.269246 +-7.03054,-0.31011,0.164704 +-6.78363,-0.0346036,0.284301 +-7.80972,-1.29454,0.102356 +-6.7481,-7.90431e-05,0.251576 +-7.02959,-0.2615,0.350582 +-7.24526,-0.565045,0.140963 +-8.13387,-1.74846,0.087442 +-6.74889,-0.000863834,0.255233 +-6.97702,-0.248824,0.172373 +-7.44263,-0.810867,0.124817 +-7.39216,-0.580154,0.406009 +-6.76683,-0.01923,0.226296 +-6.83787,-0.0944697,0.199665 +-6.75199,-0.00400831,0.238977 +-7.305,-0.504655,0.394402 +-7.38951,-0.577869,0.405667 +-7.2993,-0.631313,0.136141 +-6.74803,-1.07605e-05,0.250581 +-6.93465,-0.201017,0.179268 +-6.97211,-0.243251,0.17313 +-6.74835,-0.000324569,0.24683 +-6.77246,-0.0250718,0.223065 +-6.78379,-0.0368942,0.217601 +-7.70609,-1.15464,0.107965 +-7.87445,-0.989702,0.45992 +-7.118,-0.412277,0.153947 +-7.03287,-0.264445,0.351198 +-8.55783,-1.55495,0.519254 +-6.89884,-0.142613,0.322444 +-7.55236,-0.951942,0.117276 +-8.69527,-1.6673,0.529661 +-7.09272,-0.382498,0.156871 +-6.76676,-0.0191635,0.226335 +-8.76392,-2.69933,0.0661368 +-8.34031,-1.37638,0.501872 +-7.53165,-0.699904,0.423151 +-7.17761,-0.393162,0.375793 +-6.74999,-0.0019811,0.242216 +-7.13913,-0.437318,0.151603 +-6.76511,-0.0167461,0.273567 +-7.29906,-0.631017,0.136162 +-8.10133,-1.17885,0.48127 +-7.05392,-0.283326,0.355084 +-7.0006,-0.275709,0.16887 +-6.87564,-0.121169,0.316388 +-7.09237,-0.382087,0.156912 +-7.20718,-0.419169,0.380317 +-6.74887,-0.000845771,0.255177 +-6.80981,-0.0643852,0.207866 +-7.29637,-0.49714,0.393207 +-6.81634,-0.0657206,0.297986 +-9.47501,-2.29978,0.582042 +-10.2558,-2.92883,0.625954 +-7.1876,-0.40196,0.377338 +-6.95743,-0.226645,0.175455 +-6.88205,-0.14262,0.189286 +-6.81991,-0.0690902,0.299267 +-6.97419,-0.245617,0.172807 +-6.74981,-0.00180403,0.242568 +-6.82172,-0.0707975,0.299906 +-7.3003,-0.632542,0.136055 +-6.74852,-0.000496953,0.253962 +-7.21734,-0.428088,0.381841 +-6.96553,-0.235801,0.17416 +-7.60331,-1.01848,0.114044 +-6.7628,-0.0150713,0.228932 +-6.8314,-0.0874972,0.201413 +-6.77248,-0.0250884,0.223057 +-7.02006,-0.252926,0.34877 +-6.9047,-0.167631,0.184731 +-7.46096,-0.639371,0.414661 +-6.88756,-0.148682,0.188139 +-7.20455,-0.416864,0.379921 +-6.75271,-0.00473905,0.23803 +-8.0289,-1.59878,0.0918944 +-7.197,-0.410225,0.378775 +-7.14914,-0.368031,0.371299 +-6.75119,-0.00313665,0.260032 +-7.35562,-0.548572,0.401238 +-7.70355,-1.15124,0.108109 +-6.97134,-0.242384,0.173249 +-7.07263,-0.358981,0.159294 +-7.39221,-0.747082,0.128587 +-7.23397,-0.442663,0.384302 +-7.13142,-0.352348,0.368428 +-6.90738,-0.150468,0.324561 +-7.84466,-1.34227,0.10057 +-7.78564,-1.2618,0.103618 +-6.75014,-0.00213358,0.241925 +-6.77728,-0.030089,0.220604 +-6.75026,-0.00225377,0.241703 +-6.76486,-0.0165046,0.273391 +-9.1504,-2.03728,0.561461 +-8.02026,-1.58658,0.0922756 +-9.87602,-2.6231,0.605483 +-7.98479,-1.53665,0.0938654 +-6.76076,-0.0125152,0.270294 +-6.86257,-0.109033,0.312751 +-6.7765,-0.0292738,0.220988 +-8.66475,-2.54377,0.0689899 +-7.46822,-0.843492,0.122981 +-6.97801,-0.249956,0.172221 +-8.37584,-1.40562,0.504794 +-7.29735,-0.497998,0.393344 +-6.77538,-0.0281078,0.221547 +-6.97904,-0.251118,0.172065 +-10.5892,-3.19708,0.642715 +-7.04738,-0.329601,0.162481 +-7.18915,-0.497108,0.146372 +-6.92844,-0.16979,0.329573 +-6.74972,-0.00168419,0.257326 +-6.83747,-0.0940421,0.19977 +-8.09564,-1.69366,0.0890263 +-7.19047,-0.404484,0.377778 +-6.87494,-0.134809,0.190807 +-6.7502,-0.00216177,0.25831 +-6.94471,-0.184658,0.333263 +-7.15521,-0.373397,0.37227 +-6.74808,-5.72141e-05,0.248666 +-7.60786,-1.02446,0.113762 +-8.00154,-1.09588,0.47212 +-6.74847,-0.000445481,0.246288 +-6.8992,-0.161545,0.185798 +-7.31199,-0.646992,0.135058 +-6.74977,-0.00176091,0.242656 +-9.13015,-2.02088,0.560124 +-8.17798,-1.24238,0.488069 +-6.91651,-0.180759,0.182508 +-6.82478,-0.080384,0.203281 +-6.77798,-0.0308129,0.220268 +-8.33585,-1.3727,0.501502 +-7.18435,-0.49134,0.146856 +-6.75411,-0.00601239,0.263957 +-6.93271,-0.19885,0.179605 +-6.75325,-0.00528473,0.237371 +-6.77034,-0.0218111,0.277005 +-6.88422,-0.129113,0.318682 +-6.74803,-8.64198e-06,0.249482 +-6.81037,-0.0600686,0.295767 +-6.78114,-0.0322185,0.283051 +-6.80411,-0.0541379,0.293336 +-6.94353,-0.183588,0.333002 +-8.49445,-1.50302,0.51431 +-6.83129,-0.0873835,0.201442 +-7.68811,-1.13063,0.108989 +-7.10337,-0.395022,0.155622 +-7.37412,-0.72437,0.129992 +-6.79021,-0.0409091,0.287427 +-6.86298,-0.109417,0.312869 +-7.36442,-0.556192,0.402399 +-6.75422,-0.00612099,0.264085 +-6.82468,-0.073584,0.300933 +-7.06818,-0.296079,0.357646 +-8.69662,-1.66841,0.529762 +-6.74897,-0.000942107,0.255466 +-6.74884,-0.00082722,0.244951 +-6.74817,-0.000145258,0.252138 +-8.5384,-1.53904,0.517748 +-6.98391,-0.256659,0.171327 +-6.80279,-0.0528879,0.292808 +-6.91879,-0.183292,0.182091 +-6.81146,-0.0611068,0.296182 +-6.88968,-0.134158,0.320107 +-7.05234,-0.281908,0.354796 +-7.07729,-0.364427,0.158723 +-6.98796,-0.261273,0.170722 +-6.84679,-0.104127,0.197366 +-7.64599,-1.07471,0.111452 +-7.71681,-1.169,0.107362 +-8.57091,-1.56566,0.520262 +-7.10396,-0.395715,0.155554 +-8.55159,-2.36893,0.0724458 +-8.92289,-2.95324,0.0618719 +-7.21659,-0.427435,0.38173 +-6.7835,-0.0365888,0.217729 +-7.08896,-0.378086,0.157317 +-6.82615,-0.0818527,0.202888 +-6.75063,-0.00263244,0.241041 +-7.16775,-0.38447,0.374253 +-7.2303,-0.54685,0.142362 +-7.49807,-0.881771,0.120901 +-6.75182,-0.0037651,0.261004 +-7.42745,-0.791593,0.12593 +-7.04342,-0.325012,0.162996 +-7.48716,-0.661843,0.417852 +-8.29708,-1.34076,0.49827 +-7.32394,-0.661785,0.134054 +-6.83439,-0.08271,0.304182 +-6.89678,-0.140713,0.321924 +-7.84314,-1.34019,0.100646 +-6.83581,-0.0840394,0.304642 +-6.94937,-0.188912,0.334295 +-7.23338,-0.550591,0.142072 +-7.03817,-0.269209,0.35219 +-7.38635,-0.575139,0.405259 +-6.80827,-0.0580868,0.294967 +-7.32288,-0.52019,0.396848 +-7.66225,-0.810982,0.43793 +-8.00595,-1.5664,0.0929119 +-6.88057,-0.125735,0.317714 +-6.99718,-0.232297,0.344298 +-6.88111,-0.126234,0.317858 +-7.43011,-0.794975,0.125733 +-7.54382,-0.940857,0.117833 +-6.94673,-0.186508,0.333713 +-6.80883,-0.0633473,0.208185 +-6.76163,-0.0133649,0.270989 +-7.0686,-0.354282,0.159791 +-7.86078,-0.978251,0.458568 +-6.83443,-0.082746,0.304195 +-7.04761,-0.277675,0.353933 +-6.75103,-0.00298621,0.259785 +-6.84926,-0.106798,0.196753 +-7.96826,-1.06813,0.468987 +-13.8963,-5.89163,0.766712 +-7.07207,-0.358324,0.159363 +-7.09567,-0.385968,0.156522 +-6.80766,-0.0621016,0.208572 +-7.06426,-0.292579,0.356948 +-7.46598,-0.840636,0.123139 +-7.86374,-1.36847,0.0996142 +-6.77665,-0.029429,0.220915 +-8.88017,-1.81795,0.543041 +-6.74823,-0.000209224,0.247453 +-6.77255,-0.0239483,0.278341 +-7.45648,-0.635531,0.41411 +-6.99394,-0.22937,0.34365 +-6.74804,-1.90033e-05,0.249231 +-6.90492,-0.167873,0.184689 +-8.21591,-1.27375,0.491364 +-6.78836,-0.0416878,0.215666 +-6.76765,-0.0200821,0.225794 +-7.26964,-0.594836,0.138745 +-6.7992,-0.0494783,0.291339 +-6.7497,-0.001686,0.242813 +-6.84058,-0.0974044,0.198952 +-6.79785,-0.0481875,0.29077 +-7.43702,-0.618808,0.411697 +-8.32761,-1.36592,0.500819 +-11.3164,-7.4482,0.0248033 +-7.41346,-0.773892,0.126971 +-6.84172,-0.0895826,0.306524 +-7.28858,-0.490356,0.392122 +-6.89863,-0.160911,0.185911 +-7.14656,-0.365747,0.370884 +-6.8306,-0.0866397,0.201633 +-6.79268,-0.0432608,0.288535 +-6.7535,-0.00541075,0.263228 +-6.84275,-0.0905425,0.306844 +-7.70541,-0.847494,0.442588 +-6.78493,-0.0380847,0.217108 +-6.84814,-0.105586,0.19703 +-7.71112,-1.16138,0.107681 +-6.79377,-0.0473896,0.213518 +-7.24178,-0.560802,0.141287 +-6.84943,-0.106982,0.196711 +-6.75638,-0.00824333,0.266392 +-6.84838,-0.105842,0.196971 +-7.289,-0.618623,0.137034 +-7.15955,-0.461642,0.149416 +-7.28669,-0.488713,0.391858 +-8.97986,-3.0456,0.0604286 +-7.16444,-0.467484,0.148903 +-7.08823,-0.377233,0.157404 +-8.08164,-1.16251,0.479493 +-6.80356,-0.0577442,0.20996 +-9.05956,-1.96365,0.55541 +-7.47604,-0.652309,0.416504 +-6.98761,-0.260882,0.170773 +-7.07264,-0.358989,0.159293 +-7.45651,-0.635557,0.414114 +-7.16359,-0.466461,0.148992 +-8.27982,-1.96076,0.0817508 +-6.81947,-0.068674,0.29911 +-8.57958,-2.41191,0.0715701 +-6.74803,-4.26916e-06,0.249636 +-8.5764,-1.57015,0.520685 +-8.11834,-1.19296,0.482795 +-8.74212,-1.70552,0.533116 +-6.85378,-0.111705,0.195649 +-7.11557,-0.409412,0.154222 +-6.76006,-0.0122515,0.230949 +-6.8611,-0.119676,0.193916 +-6.82149,-0.0768604,0.204242 +-7.45113,-0.630932,0.41345 +-6.74866,-0.000638611,0.24556 +-7.17201,-0.388222,0.37492 +-6.75017,-0.00213278,0.258253 +-6.96787,-0.238445,0.173792 +-6.74991,-0.00187688,0.257737 +-6.75916,-0.0109626,0.268963 +-6.96188,-0.23167,0.17474 +-7.0143,-0.291412,0.166928 +-7.26258,-0.5862,0.139379 +-7.00583,-0.281702,0.16812 +-7.06747,-0.352963,0.159932 +-10.7982,-3.36537,0.652709 +-6.74887,-0.000845882,0.255178 +-6.91828,-0.160485,0.327192 +-6.7544,-0.00646018,0.236063 +-7.01178,-0.288525,0.16728 +-8.45745,-1.47267,0.51138 +-7.57146,-0.733859,0.427773 +-7.28321,-0.485685,0.391371 +-6.76108,-0.0128346,0.270557 +-7.92596,-1.03281,0.464943 +-8.80945,-2.77148,0.0648779 +-6.74802,-5.27019e-07,0.250129 +-6.75567,-0.00754177,0.265665 +-7.10026,-0.324681,0.363226 +-8.71582,-1.68407,0.531182 +-7.42983,-0.612623,0.410798 +-8.74075,-2.66278,0.0667895 +-6.78115,-0.0341353,0.218779 +-7.24089,-0.448727,0.385315 +-7.88554,-1.39851,0.0985396 +-6.75262,-0.00455295,0.262118 +-8.75044,-2.67805,0.0665156 +-7.54254,-0.939199,0.117917 +-6.74916,-0.00113358,0.256 +-6.85928,-0.117695,0.19434 +-6.74819,-0.000170179,0.247702 +-6.90609,-0.169171,0.184464 +-7.97912,-1.52869,0.0941234 +-7.66292,-1.09714,0.11045 +-6.81806,-0.0731904,0.205271 +-6.75421,-0.00626814,0.236268 +-7.15382,-0.372172,0.372049 +-8.38051,-1.40946,0.505175 +-8.07338,-1.66189,0.0899683 +-6.89476,-0.138849,0.321411 +-7.15695,-0.458536,0.14969 +-7.08231,-0.370298,0.158114 +-6.74866,-0.00063463,0.254481 +-6.85103,-0.108725,0.196316 +-7.00283,-0.278268,0.168549 +-7.11197,-0.405157,0.154632 +-7.05267,-0.335735,0.1618 +-6.74802,-2.30557e-06,0.25027 +-7.07971,-0.306381,0.359681 +-7.1007,-0.325076,0.363302 +-6.88295,-0.143613,0.189096 +-7.33333,-0.529254,0.398261 +-6.899,-0.161314,0.185839 +-8.74694,-2.67254,0.0666143 +-7.50255,-0.887534,0.120594 +-7.34694,-0.690384,0.132162 +-6.75752,-0.00935302,0.267484 +-6.76283,-0.0151043,0.22891 +-8.83104,-1.77797,0.539551 +-7.28501,-0.613713,0.137383 +-6.8652,-0.111475,0.313496 +-6.75948,-0.0112693,0.269232 +-10.6268,-3.22736,0.644542 +-6.80661,-0.0565087,0.294322 +-7.17676,-0.48223,0.147629 +-6.7683,-0.0198421,0.275718 +-7.48062,-0.859371,0.122109 +-6.79783,-0.0481727,0.290764 +-9.72314,-4.31592,0.0449046 +-6.7774,-0.0286218,0.281082 +-6.77075,-0.0232908,0.224003 +-6.84272,-0.0905166,0.306836 +-7.11815,-0.412456,0.15393 +-6.75261,-0.00453738,0.262097 +-7.49457,-0.877262,0.121142 +-7.85076,-0.969844,0.457571 +-6.81512,-0.0645674,0.29754 +-6.94743,-0.215367,0.177097 +-7.49578,-0.669225,0.41889 +-7.00671,-0.240895,0.346181 +-6.77723,-0.0284569,0.280989 +-6.74814,-0.000121524,0.251955 +-7.95113,-1.05383,0.467358 +-6.82298,-0.0719847,0.300346 +-9.1133,-3.26471,0.0572097 +-6.80308,-0.057235,0.210127 +-6.79873,-0.0490252,0.29114 +-7.8538,-0.972396,0.457874 +-7.15195,-0.37052,0.37175 +-6.93817,-0.178692,0.331798 +-7.18369,-0.398516,0.376735 +-6.86902,-0.128325,0.192112 +-6.86751,-0.113623,0.314146 +-7.62828,-1.05132,0.112515 +-6.74918,-0.00115492,0.256057 +-6.75265,-0.00468132,0.238102 +-7.55104,-0.716456,0.425416 +-6.97606,-0.247734,0.17252 +-6.87766,-0.137799,0.190219 +-6.76169,-0.0139343,0.229719 +-7.23598,-0.553754,0.141827 +-7.29895,-0.499386,0.393565 +-6.87791,-0.138067,0.190166 +-6.74929,-0.00126187,0.256333 +-7.51187,-0.682994,0.420813 +-8.05657,-1.14167,0.477209 +-7.89513,-1.41176,0.0980726 +-6.82509,-0.0807178,0.203191 +-7.30931,-0.508395,0.394994 +-7.05804,-0.287018,0.355831 +-6.75183,-0.00376853,0.261009 +-6.90209,-0.145599,0.323255 +-7.28986,-0.491475,0.392301 +-7.16935,-0.38588,0.374504 +-7.6679,-0.815768,0.438546 +-7.44579,-0.626347,0.412789 +-6.8576,-0.10441,0.311318 +-6.87731,-0.122713,0.316839 +-6.75566,-0.00774601,0.234767 +-6.7578,-0.00963089,0.267747 +-6.99739,-0.272034,0.169335 +-6.75229,-0.00431858,0.238565 +-7.01321,-0.246757,0.347449 +-6.74802,-1.89154e-06,0.250244 +-8.18778,-1.25049,0.488924 +-7.13254,-0.429497,0.152324 +-6.76493,-0.0172683,0.227497 +-7.25118,-0.572261,0.140418 +-6.7484,-0.000381749,0.253471 +-7.18272,-0.489386,0.147021 +-6.90326,-0.166032,0.185009 +-6.84379,-0.100876,0.198125 +-6.76758,-0.0200117,0.225835 +-6.8517,-0.0989073,0.309577 +-6.79306,-0.0436284,0.288706 +-6.95972,-0.229229,0.175086 +-6.75375,-0.00580086,0.23678 +-7.0259,-0.25818,0.349883 +-6.85924,-0.117649,0.19435 +-6.78139,-0.0324546,0.283177 +-7.74215,-0.87851,0.446475 +-7.50331,-0.888505,0.120543 +-6.82719,-0.0829706,0.202591 +-6.89431,-0.156135,0.186768 +-7.20422,-0.416571,0.379871 +-6.74806,-3.63598e-05,0.248937 +-6.80607,-0.0604097,0.209104 +-6.75671,-0.00855912,0.266709 +-6.75276,-0.00468506,0.262295 +-7.07271,-0.300127,0.358449 +-6.75597,-0.00783773,0.265975 +-7.33543,-0.531078,0.398544 +-6.84114,-0.0980117,0.198807 +-6.86244,-0.121142,0.193605 +-7.63622,-1.0618,0.112036 +-6.92031,-0.162341,0.327672 +-6.7558,-0.0076697,0.2658 +-7.60643,-1.02257,0.113851 +-7.50807,-0.679746,0.420361 +-6.77598,-0.0287368,0.221244 +-7.15236,-0.453062,0.150177 +-6.75227,-0.00420392,0.261637 +-6.93379,-0.174686,0.330803 +-6.76483,-0.0164805,0.273374 +-8.21258,-1.271,0.491077 +-6.7487,-0.000672226,0.254612 +-6.78372,-0.0368248,0.21763 +-6.86216,-0.108649,0.312633 +-9.72331,-4.31622,0.0449017 +-6.76474,-0.0170758,0.227618 +-6.75058,-0.00258437,0.241122 +-6.76179,-0.0135232,0.271116 +-6.79135,-0.0419964,0.287943 +-8.07775,-1.15928,0.47914 +-6.75725,-0.00936854,0.233284 +-8.09091,-1.68689,0.0892254 +-6.98157,-0.218169,0.341136 +-7.01577,-0.29311,0.166722 +-7.39474,-0.582385,0.406341 +-6.87969,-0.124919,0.317479 +-7.51279,-0.683782,0.420922 +-7.71513,-1.16675,0.107456 +-6.77755,-0.0287653,0.281162 +-7.24893,-0.455759,0.386483 +-6.76551,-0.0178713,0.22712 +-7.16908,-0.385639,0.374461 +-6.81451,-0.0694021,0.206364 +-7.3192,-0.655917,0.13445 +-6.75099,-0.00299216,0.240456 +-6.79618,-0.0499299,0.212608 +-6.77193,-0.0245149,0.223354 +-7.33512,-0.530808,0.398502 +-8.58321,-2.4175,0.0714575 +-6.9255,-0.167105,0.328892 +-6.8728,-0.132472,0.191273 +-6.9316,-0.172683,0.330301 +-6.84711,-0.0946197,0.30819 +-6.87822,-0.138407,0.1901 +-7.62499,-1.04699,0.112714 +-7.43195,-0.797304,0.125598 +-7.7516,-1.21576,0.105443 +-6.75417,-0.00622628,0.236313 +-6.92998,-0.171205,0.32993 +-6.76406,-0.0157325,0.272822 +-6.79484,-0.0485125,0.213112 +-6.82836,-0.0770473,0.302186 +-6.74811,-9.13933e-05,0.251695 +-7.11641,-0.410399,0.154127 +-7.54388,-0.710346,0.424583 +-7.73497,-1.19337,0.106353 +-7.2168,-0.530476,0.143651 +-6.77758,-0.0287998,0.281182 +-7.37646,-0.727305,0.129808 +-7.68535,-0.830532,0.440436 +-6.75862,-0.0104276,0.268483 +-7.8964,-1.00809,0.462075 +-6.93237,-0.173388,0.330478 +-7.37305,-0.723019,0.130077 +-6.75666,-0.00877169,0.233812 +-7.38325,-0.735827,0.129279 +-6.88293,-0.143588,0.1891 +-7.63739,-0.789904,0.435199 +-6.96337,-0.201658,0.337328 +-6.77067,-0.0232144,0.224044 +-6.78504,-0.0359616,0.284995 +-6.87441,-0.134236,0.190921 +-6.78558,-0.0387719,0.216827 +-7.29178,-0.622046,0.136791 +-8.07827,-1.66886,0.0897601 +-6.99351,-0.228977,0.343562 +-6.92529,-0.190552,0.180915 +-7.36858,-0.559784,0.402945 +-6.76356,-0.015852,0.22841 +-8.4388,-1.45736,0.509889 +-7.04193,-0.323278,0.163192 +-8.31422,-2.0115,0.0804856 +-6.77087,-0.0223233,0.27733 +-6.83784,-0.0859444,0.305295 +-6.9102,-0.173736,0.183684 +-6.84988,-0.107475,0.196598 +-9.79626,-2.55885,0.600979 +-7.99258,-1.54758,0.0935128 +-6.80078,-0.0548034,0.210932 +-7.45423,-0.633595,0.413832 +-7.05057,-0.28033,0.354475 +-6.93272,-0.198855,0.179604 +-8.86233,-2.85585,0.063454 +-6.78113,-0.0322104,0.283047 +-6.78267,-0.035721,0.218096 +-6.84668,-0.0942175,0.308058 +-7.25368,-0.45991,0.387169 +-7.98602,-1.08294,0.470664 +-6.7783,-0.031151,0.220113 +-8.47404,-1.48628,0.512697 +-7.59525,-0.754114,0.430484 +-6.80361,-0.0536635,0.293136 +-6.74884,-0.000813123,0.255076 +-7.30015,-0.500431,0.393731 +-7.59519,-1.00784,0.114548 +-8.09012,-1.68577,0.0892584 +-9.91814,-2.65702,0.607831 +-6.7891,-0.0398484,0.286917 +-6.88982,-0.134289,0.320144 +-6.81172,-0.061352,0.296279 +-6.78999,-0.0406977,0.287326 +-7.10697,-0.399261,0.155206 +-6.83344,-0.0896984,0.200852 +-7.82725,-1.31844,0.101454 +-9.86857,-2.6171,0.605065 +-7.42198,-0.605861,0.409809 +-7.06995,-0.297663,0.357961 +-6.83751,-0.0940855,0.19976 +-7.82736,-1.3186,0.101448 +-7.91924,-1.44516,0.0969134 +-6.83539,-0.0836494,0.304507 +-9.81311,-2.57243,0.601937 +-7.10406,-0.395832,0.155542 +-7.37419,-0.724457,0.129986 +-7.08919,-0.314832,0.361329 +-6.82956,-0.0855206,0.201923 +-6.82613,-0.0818353,0.202892 +-7.11565,-0.409498,0.154214 +-7.06527,-0.293479,0.357128 +-7.76049,-0.893965,0.448388 +-6.76012,-0.0118953,0.269772 +-6.99692,-0.232058,0.344245 +-8.32911,-1.36716,0.500944 +-6.75206,-0.00408307,0.238876 +-7.42634,-0.60962,0.410359 +-7.30072,-0.633068,0.136019 +-8.83728,-1.78305,0.539997 +-7.11788,-0.412141,0.15396 +-7.25548,-0.46148,0.387428 +-6.94774,-0.215715,0.177045 +-6.74902,-0.000993465,0.255614 +-6.91729,-0.159578,0.326957 +-7.33061,-0.670067,0.133499 +-8.05836,-1.14316,0.477373 +-7.25159,-0.45808,0.386867 +-7.75248,-0.887223,0.447555 +-6.77673,-0.0295135,0.220875 +-6.77151,-0.0229455,0.277721 +-7.37897,-0.730453,0.129612 +-6.93071,-0.196607,0.179955 +-7.0911,-0.380601,0.157062 +-6.84471,-0.0923718,0.307451 +-6.76203,-0.0137558,0.271301 +-6.74914,-0.00112371,0.244122 +-10.427,-3.0666,0.634695 +-6.76638,-0.0179856,0.274449 +-11.4121,-3.8604,0.680009 +-7.01033,-0.286861,0.167483 +-7.75058,-0.88562,0.447357 +-6.85686,-0.115059,0.194911 +-6.79954,-0.0498011,0.29148 +-6.74813,-0.000105655,0.248188 +-7.11559,-0.40943,0.15422 +-6.75625,-0.00834988,0.234198 +-6.77384,-0.0251946,0.279095 +-7.43216,-0.614627,0.41109 +-7.034,-0.265458,0.35141 +-6.74888,-0.000857037,0.255212 +-6.75274,-0.00466381,0.262267 +-6.82378,-0.0793094,0.203572 +-6.75102,-0.0030224,0.240409 +-6.75718,-0.00929819,0.233345 +-6.90259,-0.165293,0.185138 +-6.84233,-0.090148,0.306713 +-7.30011,-0.632306,0.136072 diff --git a/tests/testthat/resources/stan/loo_moment_match b/tests/testthat/resources/stan/loo_moment_match new file mode 100755 index 000000000..8073d9979 Binary files /dev/null and b/tests/testthat/resources/stan/loo_moment_match differ diff --git a/tests/testthat/test-csv.R b/tests/testthat/test-csv.R index 4c1215451..6dc331fee 100644 --- a/tests/testthat/test-csv.R +++ b/tests/testthat/test-csv.R @@ -3,9 +3,11 @@ context("read_cmdstan_csv") set_cmdstan_path() fit_bernoulli_optimize <- testing_fit("bernoulli", method = "optimize", seed = 1234) fit_bernoulli_variational <- testing_fit("bernoulli", method = "variational", seed = 123) +fit_bernoulli_laplace <- testing_fit("bernoulli", method = "laplace", seed = 123) fit_logistic_optimize <- testing_fit("logistic", method = "optimize", seed = 123) fit_logistic_variational <- testing_fit("logistic", method = "variational", seed = 123) fit_logistic_variational_short <- testing_fit("logistic", method = "variational", output_samples = 100, seed = 123) +fit_logistic_laplace <- testing_fit("logistic", method = "laplace", seed = 123) fit_bernoulli_diag_e_no_samples <- testing_fit("bernoulli", method = "sample", seed = 123, chains = 2, iter_sampling = 0, metric = "diag_e") @@ -368,6 +370,26 @@ test_that("read_cmdstan_csv() works for variational", { ) }) +test_that("read_cmdstan_csv() works for laplace", { + csv_output_1 <- read_cmdstan_csv(fit_bernoulli_laplace$output_files()) + csv_output_2 <- read_cmdstan_csv(fit_logistic_laplace$output_files()) + expect_equal(dim(csv_output_1$draws), c(1000, 3)) + expect_equal(dim(csv_output_2$draws), c(1000, 6)) + + csv_file <- test_path("resources", "csv", "bernoulli-1-laplace.csv") + csv_output_3 <- read_cmdstan_csv(csv_file) + expect_equal(as.numeric(csv_output_3$draws[1,"theta"]), 0.374086) + expect_equal(dim(csv_output_3$draws), c(1000, 3)) + expect_equal(csv_output_3$metadata$variables, c("lp__", "lp_approx__", "theta")) + + # variable filtering + csv_output_4 <- read_cmdstan_csv(fit_logistic_laplace$output_files(), variables = "beta") + expect_equal(posterior::variables(csv_output_4$draws), c("beta[1]", "beta[2]", "beta[3]")) + csv_output_5 <- read_cmdstan_csv(fit_logistic_laplace$output_files(), variables = c("alpha", "beta[2]")) + expect_equal(posterior::variables(csv_output_5$draws), c("alpha", "beta[2]")) +}) + + test_that("read_cmdstan_csv() works for generate_quantities", { csv_output_1 <- read_cmdstan_csv(fit_gq$output_files()) expect_equal(dim(csv_output_1$generated_quantities), c(1000, 2, 11)) @@ -417,28 +439,34 @@ test_that("read_cmdstan_csv() errors for files from different methods", { test_that("stan_variables and stan_variable_sizes works in read_cdmstan_csv()", { bern_opt <- read_cmdstan_csv(fit_bernoulli_optimize$output_files()) bern_vi <- read_cmdstan_csv(fit_bernoulli_variational$output_files()) + bern_lap <- read_cmdstan_csv(fit_bernoulli_laplace$output_files()) log_opt <- read_cmdstan_csv(fit_logistic_optimize$output_files()) log_vi <- read_cmdstan_csv(fit_logistic_variational$output_files()) + log_lap <- read_cmdstan_csv(fit_logistic_laplace$output_files()) bern_samp <- read_cmdstan_csv(fit_bernoulli_thin_1$output_files()) log_samp <- read_cmdstan_csv(fit_logistic_thin_1$output_files()) gq <- read_cmdstan_csv(fit_gq$output_files()) expect_equal(bern_opt$metadata$stan_variables, c("lp__", "theta")) expect_equal(bern_vi$metadata$stan_variables, c("lp__", "lp_approx__", "theta")) + expect_equal(bern_lap$metadata$stan_variables, c("lp__", "lp_approx__", "theta")) expect_equal(bern_samp$metadata$stan_variables, c("lp__", "theta")) expect_equal(log_opt$metadata$stan_variables, c("lp__", "alpha", "beta")) expect_equal(log_vi$metadata$stan_variables, c("lp__", "lp_approx__", "alpha", "beta")) + expect_equal(log_lap$metadata$stan_variables, c("lp__", "lp_approx__", "alpha", "beta")) expect_equal(log_samp$metadata$stan_variables, c("lp__", "alpha", "beta")) expect_equal(gq$metadata$stan_variables, c("y_rep","sum_y")) expect_equal(bern_opt$metadata$stan_variable_sizes, list(lp__ = 1, theta = 1)) expect_equal(bern_vi$metadata$stan_variable_sizes, list(lp__ = 1, lp_approx__ = 1, theta = 1)) + expect_equal(bern_lap$metadata$stan_variable_sizes, list(lp__ = 1, lp_approx__ = 1, theta = 1)) expect_equal(bern_samp$metadata$stan_variable_sizes, list(lp__ = 1, theta = 1)) expect_equal(log_opt$metadata$stan_variable_sizes, list(lp__ = 1, alpha = 1, beta = 3)) expect_equal(log_vi$metadata$stan_variable_sizes, list(lp__ = 1, lp_approx__ = 1, alpha = 1, beta = 3)) + expect_equal(log_lap$metadata$stan_variable_sizes, list(lp__ = 1, lp_approx__ = 1, alpha = 1, beta = 3)) expect_equal(log_samp$metadata$stan_variable_sizes, list(lp__ = 1, alpha = 1, beta = 3)) expect_equal(gq$metadata$stan_variable_sizes, list(y_rep = 10, sum_y = 1)) @@ -496,11 +524,13 @@ test_that("as_cmdstan_fit creates fitted model objects from csv", { fits <- list( mle = as_cmdstan_fit(fit_logistic_optimize$output_files()), vb = as_cmdstan_fit(fit_logistic_variational$output_files()), + laplace = as_cmdstan_fit(fit_logistic_laplace$output_files()), mcmc = as_cmdstan_fit(fit_logistic_thin_1$output_files()) ) for (class in names(fits)) { fit <- fits[[class]] - checkmate::expect_r6(fit, classes = paste0("CmdStan", toupper(class), "_CSV")) + class_name <- if (class == "laplace") "Laplace" else toupper(class) + checkmate::expect_r6(fit, classes = paste0("CmdStan", class_name, "_CSV")) expect_s3_class(fit$draws(), "draws") checkmate::expect_numeric(fit$lp()) expect_output(fit$print(), "variable") @@ -519,6 +549,9 @@ test_that("as_cmdstan_fit creates fitted model objects from csv", { if (class == "vb") { checkmate::expect_numeric(fit$lp_approx()) } + if (class == "laplace") { + checkmate::expect_numeric(fit$lp_approx()) + } for (method in unavailable_methods_CmdStanFit_CSV) { if (!(method == "time" && class == "mcmc")) { @@ -548,9 +581,11 @@ test_that("as_cmdstan_fit can check MCMC diagnostics", { test_that("read_cmdstan_csv reads seed correctly", { opt <- read_cmdstan_csv(fit_bernoulli_optimize$output_files()) vi <- read_cmdstan_csv(fit_bernoulli_variational$output_files()) + lap <- read_cmdstan_csv(fit_bernoulli_laplace$output_files()) smp <- read_cmdstan_csv(fit_bernoulli_diag_e_no_samples$output_files()) expect_equal(opt$metadata$seed, 1234) expect_equal(vi$metadata$seed, 123) + expect_equal(lap$metadata$seed, 123) expect_equal(smp$metadata$seed, 123) }) @@ -649,6 +684,56 @@ test_that("read_cmdstan_csv works with optimization and draws_list format", { }) +test_that("read_cmdstan_csv works with laplace and draws_array format", { + bern_laplace <- read_cmdstan_csv(fit_bernoulli_laplace$output_files()) + bern_laplace_array <- read_cmdstan_csv(fit_bernoulli_laplace$output_files(), format = "array") + + expect_equal(posterior::niterations(bern_laplace$draws), + posterior::niterations(bern_laplace_array$draws)) + expect_equal(posterior::nvariables(bern_laplace$draws), + posterior::nvariables(bern_laplace_array$draws)) + expect_equal(posterior::variables(bern_laplace$draws), + posterior::variables(bern_laplace_array$draws)) + + expect_equal(as.numeric(posterior::subset_draws(bern_laplace$draws, variable = "theta")), + as.numeric(posterior::subset_draws(bern_laplace_array$draws, variable = "theta"))) +}) + +test_that("read_cmdstan_csv works with laplace and draws_df format", { + bern_laplace <- read_cmdstan_csv(fit_bernoulli_laplace$output_files()) + bern_laplace_df <- read_cmdstan_csv(fit_bernoulli_laplace$output_files(), format = "df") + + expect_equal(posterior::niterations(bern_laplace$draws), + posterior::niterations(bern_laplace_df$draws)) + expect_equal(posterior::nchains(bern_laplace$draws), + posterior::nchains(bern_laplace_df$draws)) + expect_equal(posterior::nvariables(bern_laplace$draws), + posterior::nvariables(bern_laplace_df$draws)) + expect_equal(posterior::variables(bern_laplace$draws), + posterior::variables(bern_laplace_df$draws)) + + expect_equal(as.numeric(posterior::subset_draws(bern_laplace$draws, variable = "theta")), + as.numeric(posterior::subset_draws(bern_laplace_df$draws, variable = "theta")$theta)) +}) + +test_that("read_cmdstan_csv works with laplace and draws_list format", { + bern_laplace <- read_cmdstan_csv(fit_bernoulli_laplace$output_files()) + bern_laplace_list <- read_cmdstan_csv(fit_bernoulli_laplace$output_files(), format = "list") + + expect_equal(posterior::niterations(bern_laplace$draws), + posterior::niterations(bern_laplace_list$draws)) + expect_equal(posterior::nchains(bern_laplace$draws), + posterior::nchains(bern_laplace_list$draws)) + expect_equal(posterior::nvariables(bern_laplace$draws), + posterior::nvariables(bern_laplace_list$draws)) + expect_equal(posterior::variables(bern_laplace$draws), + posterior::variables(bern_laplace_list$draws)) + + expect_equal(as.numeric(posterior::subset_draws(bern_laplace$draws, variable = "theta")), + as.numeric(posterior::subset_draws(bern_laplace_list$draws, variable = "theta")[[1]]$theta)) + +}) + test_that("read_cmdstan_csv works with VI and draws_array format", { bern_vi <- read_cmdstan_csv(fit_bernoulli_variational$output_files()) bern_vi_array <- read_cmdstan_csv(fit_bernoulli_variational$output_files(), format = "array") @@ -662,7 +747,7 @@ test_that("read_cmdstan_csv works with VI and draws_array format", { expect_equal(as.numeric(posterior::subset_draws(bern_vi$draws, variable = "theta")), as.numeric(posterior::subset_draws(bern_vi_array$draws, variable = "theta"))) - }) +}) test_that("read_cmdstan_csv works with VI and draws_df format", { bern_vi <- read_cmdstan_csv(fit_bernoulli_variational$output_files()) @@ -796,6 +881,9 @@ test_that("read_cmdstan_csv works if no variables are specified", { expect_silent( read_cmdstan_csv(fit_bernoulli_variational$output_files(), variables = "", sampler_diagnostics = "") ) + expect_silent( + read_cmdstan_csv(fit_bernoulli_laplace$output_files(), variables = "", sampler_diagnostics = "") + ) expect_silent( read_cmdstan_csv(fit_bernoulli_thin_1$output_files(), variables = "", sampler_diagnostics = "") ) diff --git a/tests/testthat/test-fit-laplace.R b/tests/testthat/test-fit-laplace.R new file mode 100644 index 000000000..af94bebc8 --- /dev/null +++ b/tests/testthat/test-fit-laplace.R @@ -0,0 +1,66 @@ +context("fitted-vb") + +set_cmdstan_path() +fit_laplace <- testing_fit("logistic", method = "laplace", seed = 100) +PARAM_NAMES <- c("alpha", "beta[1]", "beta[2]", "beta[3]") + + +test_that("summary() and print() methods works after laplace", { + x <- fit_laplace$summary() + expect_s3_class(x, "draws_summary") + expect_equal(x$variable, c("lp__", "lp_approx__", PARAM_NAMES)) + + x <- fit_laplace$summary(variables = NULL, c("mean", "sd")) + expect_s3_class(x, "draws_summary") + expect_equal(x$variable, c("lp__", "lp_approx__", PARAM_NAMES)) + expect_equal(colnames(x), c("variable", "mean", "sd")) + + expect_output(expect_s3_class(fit_laplace$print(), "CmdStanLaplace"), "variable") + expect_output(fit_laplace$print(max_rows = 1), "# showing 1 of 6 rows") +}) + +test_that("draws() method returns posterior sample (reading csv works)", { + draws <- fit_laplace$draws() + expect_type(draws, "double") + expect_s3_class(draws, "draws_matrix") + expect_equal(posterior::variables(draws), c("lp__", "lp_approx__", PARAM_NAMES)) +}) + +test_that("lp(), lp_approx() methods return vectors (reading csv works)", { + lp <- fit_laplace$lp() + lg <- fit_laplace$lp_approx() + expect_type(lp, "double") + expect_type(lg, "double") + expect_equal(length(lp), nrow(fit_laplace$draws())) + expect_equal(length(lg), length(lp)) +}) + +test_that("time() method works after laplace", { + run_times <- fit_laplace$time() + checkmate::expect_list(run_times, names = "strict", any.missing = FALSE) + testthat::expect_named(run_times, c("total")) + checkmate::expect_number(run_times$total, finite = TRUE) +}) + +test_that("output() works for laplace", { + expect_output(fit_laplace$output(), "method = laplace") +}) + +test_that("draws() works for different formats", { + a <- fit_laplace$draws() + expect_true(posterior::is_draws_matrix(a)) + a <- fit_laplace$draws(format = "list") + expect_true(posterior::is_draws_list(a)) + a <- fit_laplace$draws(format = "array") + expect_true(posterior::is_draws_array(a)) + a <- fit_laplace$draws(format = "df") + expect_true(posterior::is_draws_df(a)) +}) + +test_that("draws() errors if invalid format", { + expect_error( + fit_laplace$draws(format = "bad_format"), + "The supplied draws format is not valid" + ) +}) + diff --git a/tests/testthat/test-fit-shared.R b/tests/testthat/test-fit-shared.R index 193a96645..85c69eb4d 100644 --- a/tests/testthat/test-fit-shared.R +++ b/tests/testthat/test-fit-shared.R @@ -7,9 +7,10 @@ fits[["sample"]] <- testing_fit("logistic", method = "sample", fits[["variational"]] <- testing_fit("logistic", method = "variational", seed = 123, save_latent_dynamics = TRUE) fits[["optimize"]] <- testing_fit("logistic", method = "optimize", seed = 123) +fits[["laplace"]] <- testing_fit("logistic", method = "laplace", seed = 123) fit_bern <- testing_fit("bernoulli", method = "sample", seed = 123) fits[["generate_quantities"]] <- testing_fit("bernoulli_ppc", method = "generate_quantities", fitted_params = fit_bern, seed = 123) -all_methods <- c("sample", "optimize", "variational", "generate_quantities") +all_methods <- c("sample", "optimize", "laplace", "variational", "generate_quantities") test_that("*_files() methods return the right number of paths", { @@ -110,9 +111,9 @@ test_that("saving data file works", { test_that("cmdstan_summary() and cmdstan_diagnose() work correctly", { for (method in all_methods) { fit <- fits[[method]] - if (method == "optimize") { - expect_error(fit$cmdstan_summary(), "Not available for optimize method") - expect_error(fit$cmdstan_diagnose(), "Not available for optimize method") + if (method %in% c("optimize", "laplace")) { + expect_error(fit$cmdstan_summary(), "Not available") + expect_error(fit$cmdstan_diagnose(), "Not available") } else if (method == "generate_quantities") { expect_error(fit$cmdstan_summary(), "Not available for generate_quantities method") expect_error(fit$cmdstan_diagnose(), "Not available for generate_quantities method") @@ -210,6 +211,7 @@ test_that("init() errors if no inits specified", { test_that("return_codes method works properly", { expect_equal(fits[["variational"]]$return_codes(), 0) expect_equal(fits[["optimize"]]$return_codes(), 0) + expect_equal(fits[["laplace"]]$return_codes(), 0) expect_equal(fits[["sample"]]$return_codes(), c(0,0,0,0)) expect_equal(fits[["generate_quantities"]]$return_codes(), c(0,0,0,0)) @@ -483,6 +485,14 @@ test_that("CmdStanModel created with exe_file works", { ) expect_equal(fit_optimize$mle(), fit_optimize_exe$mle()) + utils::capture.output( + fit_laplace <- mod$laplace(data = data_list, seed = 123) + ) + utils::capture.output( + fit_laplace_exe <- mod_exe$laplace(data = data_list, seed = 123) + ) + expect_equal(fit_laplace$draws(), fit_laplace_exe$draws()) + utils::capture.output( fit_variational <- mod$variational(data = data_list, seed = 123) ) @@ -524,7 +534,7 @@ test_that("CmdStanModel created with exe_file works", { test_that("code() works with all fitted model objects", { code_ans <- readLines(testing_stan_file("logistic")) - for (method in c("sample", "optimize", "variational")) { + for (method in c("sample", "optimize", "laplace", "variational")) { expect_identical(fits[[method]]$code(), code_ans) } code_ans_gq <- readLines(testing_stan_file("bernoulli_ppc")) diff --git a/tests/testthat/test-model-init.R b/tests/testthat/test-model-init.R index 221c8dfb3..131f4c6a0 100644 --- a/tests/testthat/test-model-init.R +++ b/tests/testthat/test-model-init.R @@ -22,6 +22,9 @@ test_that("all fitting methods work with provided init files", { expect_vb_output( mod$variational(data = data_list, init = init_json_1, seed = 123) ) + expect_laplace_output( + mod$laplace(data = data_list, init = init_json_1, seed = 123) + ) # broadcasting expect_sample_output( @@ -91,6 +94,11 @@ test_that("init can be a list of lists", { ) expect_length(fit$metadata()$init, 1) + expect_laplace_output( + fit <- mod_logistic$laplace(data = data_list_logistic, init = init_list[1], seed = 123) + ) + expect_length(fit$metadata()$init, 1) + expect_sample_output( fit <- mod_logistic$sample(data = data_list_logistic, chains = 2, init = init_list, seed = 123), num_chains = 2 diff --git a/tests/testthat/test-model-laplace.R b/tests/testthat/test-model-laplace.R new file mode 100644 index 000000000..b9b53fd96 --- /dev/null +++ b/tests/testthat/test-model-laplace.R @@ -0,0 +1,130 @@ +context("model-laplace") + +set_cmdstan_path() +mod <- testing_model("logistic") +data_list <- testing_data("logistic") + +# these are all valid for laplace() +ok_arg_values <- list( + data = data_list, + refresh = 100, + init = NULL, + seed = 12345, + mode = NULL, + draws = 100, + jacobian = TRUE, + opt_args = list( + algorithm = "lbfgs", + iter = 100, + init_alpha = 0.002, + tol_obj = 1e-11 + ) +) + +# using any of these should cause optimize() to error +bad_arg_values <- list( + data = "NOT_A_FILE", + refresh = -20, + init = "NOT_A_FILE", + seed = "NOT_A_SEED", + jacobian = 30, + draws = -10, + mode = 10 +) + +test_that("laplace() method errors for any invalid argument before calling cmdstan", { + for (nm in names(bad_arg_values)) { + args <- ok_arg_values + args[[nm]] <- bad_arg_values[[nm]] + expect_error(do.call(mod$laplace, args), regexp = nm, info = nm) + } + args <- ok_arg_values + args$opt_args <- list(iter = "NOT_A_NUMBER") + expect_error(do.call(mod$laplace, args), regexp = "Must be of type 'integerish'") +}) + +test_that("laplace() runs when all arguments specified validly", { + # specifying all arguments validly + expect_laplace_output(fit1 <- do.call(mod$laplace, ok_arg_values)) + expect_is(fit1, "CmdStanLaplace") + + # check that correct arguments were indeed passed to CmdStan + expect_equal(fit1$metadata()$refresh, ok_arg_values$refresh) + expect_equal(fit1$metadata()$jacobian, as.integer(ok_arg_values$jacobian)) + expect_equal(fit1$metadata()$draws, as.integer(ok_arg_values$draws)) + expect_equal(fit1$mode()$metadata()$jacobian, as.integer(ok_arg_values$jacobian)) + expect_equal(fit1$mode()$metadata()$init_alpha, ok_arg_values$opt_args$init_alpha) + expect_equal(fit1$mode()$metadata()$tol_obj, ok_arg_values$opt_args$tol_obj, tolerance = 0) + + # leaving all at default (except 'data') + expect_laplace_output(fit2 <- mod$laplace(data = data_list, seed = 123)) + expect_is(fit2, "CmdStanLaplace") +}) + +test_that("laplace() all valid 'mode' inputs give same results", { + mode <- mod$optimize(data = data_list, jacobian = TRUE, seed = 100, refresh = 0) + fit1 <- mod$laplace(data = data_list, mode = mode, seed = 100, refresh = 0) + fit2 <- mod$laplace(data = data_list, mode = mode$output_files(), seed = 100, refresh = 0) + fit3 <- mod$laplace(data = data_list, mode = NULL, seed = 100, refresh = 0) + + expect_is(fit1, "CmdStanLaplace") + expect_is(fit2, "CmdStanLaplace") + expect_is(fit3, "CmdStanLaplace") + expect_is(fit1$mode(), "CmdStanMLE") + expect_is(fit2$mode(), "CmdStanMLE") + expect_is(fit3$mode(), "CmdStanMLE") + expect_equal(fit1$mode()$mle(), fit2$mode()$mle()) + expect_equal(fit1$mode()$mle(), fit3$mode()$mle()) + expect_equal(fit1$lp(), fit2$lp()) + expect_equal(fit1$lp(), fit3$lp()) + expect_equal(fit1$lp_approx(), fit2$lp_approx()) + expect_equal(fit1$lp_approx(), fit3$lp_approx()) + expect_equal(fit1$draws(), fit2$draws()) + expect_equal(fit1$draws(), fit3$draws()) +}) + +test_that("laplace() allows choosing number of draws", { + fit <- mod$laplace(data = data_list, draws = 10, refresh = 0) + expect_equal(fit$metadata()$draws, 10) + expect_equal(posterior::ndraws(fit$draws()), 10) + + fit2 <- mod$laplace(data = data_list, draws = 100, refresh = 0) + expect_equal(fit2$metadata()$draws, 100) + expect_equal(posterior::ndraws(fit2$draws()), 100) +}) + +test_that("laplace() errors if jacobian arg doesn't match what optimize used", { + fit <- mod$optimize(data = data_list, jacobian = FALSE, refresh = 0) + expect_error( + mod$laplace(data = data_list, mode = fit, jacobian = TRUE), + "'jacobian' argument to optimize and laplace must match" + ) + expect_error( + mod$laplace(data = data_list, mode = fit, jacobian = TRUE), + "laplace was called with jacobian=TRUE\noptimize was run with jacobian=FALSE" + ) +}) + +test_that("laplace() errors with bad combinations of arguments", { + fit <- mod$optimize(data = data_list, jacobian = TRUE, refresh = 0) + expect_error( + mod$laplace(data = data_list, mode = mod, opt_args = list(iter = 10)), + "Cannot specify both 'opt_args' and 'mode' arguments." + ) + expect_error( + mod$laplace(data = data_list, mode = rnorm(10)), + "If not NULL or a CmdStanMLE object then 'mode' must be a path to a CSV file" + ) +}) + +test_that("laplace() errors if optimize() fails", { + mod_schools <- testing_model("schools") + expect_error( + expect_message( + mod_schools$laplace(data = testing_data("schools"), refresh = 0), + "Line search failed to achieve a sufficient decrease" + ), + "Optimization failed" + ) + +}) diff --git a/tests/testthat/test-model-variational.R b/tests/testthat/test-model-variational.R index 546bc0e5b..6c917d1e5 100644 --- a/tests/testthat/test-model-variational.R +++ b/tests/testthat/test-model-variational.R @@ -19,7 +19,7 @@ ok_arg_values <- list( adapt_iter = 51, tol_rel_obj = 0.011, eval_elbo = 101, - output_samples = 10, + draws = 10, save_latent_dynamics = FALSE ) @@ -38,7 +38,7 @@ bad_arg_values <- list( adapt_iter = -10, tol_rel_obj = -0.5, eval_elbo = -10, - output_samples = -10, + draws = -10, save_latent_dynamics = "NOT_LOGICAL" ) diff --git a/vignettes/cmdstanr.Rmd b/vignettes/cmdstanr.Rmd index 30bdd8cc7..0addfedfb 100644 --- a/vignettes/cmdstanr.Rmd +++ b/vignettes/cmdstanr.Rmd @@ -203,9 +203,9 @@ fit$summary( ``` ```{r, echo=FALSE} - # NOTE: the hack of using print.data.frame in chunks with echo=FALSE + # NOTE: the hack of using print.data.frame in chunks with echo=FALSE # is used because the pillar formatting of posterior draws_summary objects - # isn't playing nicely with pkgdown::build_articles(). + # isn't playing nicely with pkgdown::build_articles(). options(digits = 2) print.data.frame(fit$summary()) print.data.frame(fit$summary(variables = c("theta", "lp__"), "mean", "sd")) @@ -341,8 +341,8 @@ stanfit <- rstan::read_stan_csv(fit$output_files()) CmdStanR also supports running Stan's optimization algorithms and its algorithms for variational approximation of full Bayesian inference. These are run via the -`$optimize()` and `$variational()` methods, which are called in a similar way to -the `$sample()` method demonstrated above. +`$optimize()`, `$laplace()`, and `$variational()` methods, which are called in a +similar way to the `$sample()` method demonstrated above. ### Optimization @@ -362,6 +362,49 @@ mcmc_hist(fit$draws("theta")) + vline_at(fit_mle$mle("theta"), size = 1.5) ``` +For optimization, by default the mode is calculated without the Jacobian +adjustment for constrained variables, which shifts the mode due to the change of +variables. To include the Jacobian adjustment and obtain a maximum a posteriori +(MAP) estimate set `jacobian=TRUE`. See the +[Maximum Likelihood Estimation](https://mc-stan.org/docs/cmdstan-guide/maximum-likelihood-estimation.html) +section of the CmdStan User's Guide for more details. + +```{r optimize-map} +fit_map <- mod$optimize( + data = data_list, + jacobian = TRUE, + seed = 123 +) +``` + +### Laplace Approximation + +The [`$laplace()`](https://mc-stan.org/cmdstanr/reference/model-method-laplace.html) +produces a sample from a normal approximation centered at the mode of a +distribution in the unconstrained space. If the mode is a MAP estimate, the +samples provide an estimate of the mean and standard deviation of the posterior +distribution. If the mode is the MLE, the sample provides an estimate of the +standard error of the likelihood. Whether the mode is the MAP or MLE depends on +the value of the `jacobian` argument when running optimization. See the +[Laplace Sampling](https://mc-stan.org/docs/cmdstan-guide/laplace-sampling.html) +chapter of the CmdStan User's Guide for more details. + +Here we pass in the `fit_map` object from above as the `mode` argument. If +`mode` is omitted then optimization will be run internally before taking draws +from the normal approximation. + +```{r laplace} +fit_laplace <- mod$laplace( + mode = fit_map, + draws = 4000, + data = data_list, + seed = 123, + refresh = 1000 + ) +fit_laplace$summary("theta") +mcmc_hist(fit_laplace$draws("theta"), binwidth = 0.025) +``` + ### Variational Bayes We can run Stan's experimental variational Bayes algorithm (ADVI) using the @@ -369,27 +412,39 @@ We can run Stan's experimental variational Bayes algorithm (ADVI) using the method. ```{r variational} -fit_vb <- mod$variational(data = data_list, seed = 123, output_samples = 4000) +fit_vb <- mod$variational( + data = data_list, + seed = 123, + draws = 4000 +) fit_vb$print("theta") ``` -The `$draws()` method can be used to access the approximate posterior draws. -Let's extract the draws, make the same plot we made after MCMC, and compare the -two. In this trivial example the distributions look quite similar, although -the variational approximation slightly underestimates the posterior -standard deviation. +Let's extract the draws, make the same plot we made after MCMC and Laplace +approximation, and compare them all. In this simple example the distributions +are quite similar, but this will not always be the case. -```{r plot-variational-1, message = FALSE, fig.cap="Posterior from MCMC"} -mcmc_hist(fit$draws("theta"), binwidth = 0.025) +```{r plot-compare-vb, message = FALSE} +mcmc_hist(fit_vb$draws("theta"), binwidth = 0.025) + + ggplot2::labs(subtitle = "Approximate posterior from variational") + + ggplot2::xlim(0, 1) +``` +```{r plot-compare-laplace, message = FALSE} +mcmc_hist(fit_laplace$draws("theta"), binwidth = 0.025) + + ggplot2::labs(subtitle = "Approximate posterior from Laplace") + + ggplot2::xlim(0, 1) ``` -```{r plot-variational-2, message = FALSE, fig.cap="Posterior from variational"} -mcmc_hist(fit_vb$draws("theta"), binwidth = 0.025) +```{r plot-compare-mcmc, message = FALSE} +mcmc_hist(fit$draws("theta"), binwidth = 0.025) + + ggplot2::labs(subtitle = "Posterior from MCMC") + + ggplot2::xlim(0, 1) ``` -For more details on the `$optimize()` and `$variational()` methods, follow -these links to their documentation pages. +For more details on the `$optimize()`, `$laplace()` and `$variational()` +methods, follow these links to their documentation pages. * [`$optimize()`](https://mc-stan.org/cmdstanr/reference/model-method-optimize.html) +* [`$laplace()`](https://mc-stan.org/cmdstanr/reference/model-method-laplace.html) * [`$variational()`](https://mc-stan.org/cmdstanr/reference/model-method-variational.html) @@ -408,7 +463,7 @@ fit2 <- readRDS("fit.RDS") But if your model object is large, then [`$save_object()`](http://mc-stan.org/cmdstanr/reference/fit-method-save_object.html) -could take a long time. +could take a long time. [`$save_object()`](http://mc-stan.org/cmdstanr/reference/fit-method-save_object.html) reads the CmdStan results files into memory, stores them in the model object, and saves the object with `saveRDS()`. To speed up the process, you can emulate