diff --git a/DESCRIPTION b/DESCRIPTION index 726186ec..8f690fee 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -53,5 +53,6 @@ Suggests: loo (>= 2.0.0), rmarkdown, testthat (>= 2.1.0), - Rcpp + Rcpp, + bridgesampling VignetteBuilder: knitr diff --git a/R/fit.R b/R/fit.R index 5f07c5ee..ee8b3aa6 100644 --- a/R/fit.R +++ b/R/fit.R @@ -574,9 +574,9 @@ unconstrain_draws <- function(files = NULL, draws = NULL, } else { draws <- self$draws(inc_warmup = inc_warmup) } - + draws <- maybe_convert_draws_format(draws, "draws_matrix") - + chains <- posterior::nchains(draws) model_par_names <- self$metadata()$stan_variables[self$metadata()$stan_variables != "lp__"] @@ -595,7 +595,7 @@ unconstrain_draws <- function(files = NULL, draws = NULL, uncon_names <- private$model_methods_env_$unconstrained_param_names(private$model_methods_env_$model_ptr_, FALSE, FALSE) names(unconstrained) <- repair_variable_names(uncon_names) unconstrained$.nchains <- chains - + do.call(function(...) { create_draws_format(format, ...) }, unconstrained) } CmdStanFit$set("public", name = "unconstrain_draws", value = unconstrain_draws) @@ -1580,6 +1580,86 @@ loo <- function(variables = "log_lik", r_eff = TRUE, moment_match = FALSE, ...) } CmdStanMCMC$set("public", name = "loo", value = loo) +#' Marginal Log-Likelihood Approximation via Bridge Sampling +#' +#' @name fit-method-bridge_sampler +#' @aliases bridge_sampler +#' @description The `$bridge_sampler()` method computes the marginal likelihood +#' approximation using bridge sampling. This method requires the +#' \pkg{bridgesampling} package. +#' +#' @param method (character) The method to use for bridge sampling. Options are +#' `"normal"` (default) or `"warp3"`. +#' @param repetitions (integer) The number of repetitions for bridge sampling. +#' @param cores (integer) The number of cores to be used by the \pkg{bridgesampling} +#' package. Defaults to `1`. See the \pkg{bridgesampling} package documentation +#' for more details. 
+#' @param use_neff (logical) Whether to use the effective sample size (ESS) in
+#'   the optimal bridge function. Default is TRUE. If FALSE, the number of samples
+#'   is used instead.
+#' @param maxiter (integer) The maximum number of iterations for bridge sampling.
+#' @param silent (logical) Whether to suppress output from the bridge sampling
+#'   algorithm. Defaults to `FALSE`.
+#' @param verbose (logical) Whether to print verbose output. Defaults to `FALSE`.
+#' @param ... Other arguments passed to the bridge sampling function.
+#'
+#' @return The object returned by the bridge sampling function.
+#'
+#' @seealso The \pkg{bridgesampling} package website with
+#'   [documentation](https://cran.r-project.org/package=bridgesampling).
+#'
+#' @examples
+#'
+#' \dontrun{
+#' fit <- cmdstanr_example("logistic")
+#' bridge_result <- fit$bridge_sampler()
+#' print(bridge_result)
+#' }
+#'
+bridge_sampler <- function(method = "normal", repetitions = 1, cores = 1,
+                           use_neff = TRUE, maxiter = 1000, silent = FALSE,
+                           verbose = FALSE, ...) {
+  require_suggested_package("bridgesampling")
+  # Reject unsupported methods up front; the lookup below would yield NULL.
+  method <- match.arg(method, c("normal", "warp3"))
+  self$init_model_methods()
+
+  upars <- self$unconstrain_draws(format = "draws_array")
+  nr <- posterior::niterations(upars)
+  half_iter <- nr %/% 2
+  # First half of the iterations -> samples_4_fit, second half -> samples_4_iter.
+  samples_4_fit <- posterior::subset_draws(upars, iteration = seq_len(half_iter))
+  samples_4_iter <- posterior::subset_draws(upars, iteration = seq.int(half_iter + 1, nr))
+
+  # Honour `use_neff`: a NULL `neff` lets bridgesampling use the number of draws.
+  neff <- NULL
+  if (isTRUE(use_neff)) {
+    par_ess <- posterior::summarise_draws(samples_4_iter, "ess_median")$ess_median
+    neff <- posterior::quantile2(par_ess, 0.5)
+  }
+
+  # Draws are on the unconstrained scale, so every parameter is unbounded.
+  parameters <- posterior::variables(upars)
+  transTypes <- stats::setNames(rep("unbounded", length(parameters)), parameters)
+  lb <- stats::setNames(rep(-Inf, length(parameters)), parameters)
+  ub <- stats::setNames(rep(Inf, length(parameters)), parameters)
+  samples_4_fit <- posterior::as_draws_matrix(samples_4_fit)
+  samples_4_iter <- posterior::as_draws_matrix(samples_4_iter)
+  colnames(samples_4_fit) <- paste0("trans_", parameters)
+  colnames(samples_4_iter) <- paste0("trans_", parameters)
+
+  bridge_fun <- rlang::ns_env("bridgesampling")[[paste0(".bridge.sampler.", method)]]
+  do.call(bridge_fun, args = list(
+    samples_4_fit = samples_4_fit, samples_4_iter = samples_4_iter, neff = neff,
+    log_posterior = function(s.row, data) data$fitobj$log_prob(s.row),
+    data = list(fitobj = self), lb = lb, ub = ub, transTypes = transTypes,
+    param_types = rep("real", ncol(samples_4_fit)), repetitions = repetitions,
+    cores = cores, maxiter = maxiter, silent = silent, verbose = verbose,
+    r0 = 0.5, tol1 = 1e-10, tol2 = 1e-4, ...
+  ))
+}
+CmdStanMCMC$set("public", name = "bridge_sampler", value = bridge_sampler)
+
 #' Extract sampler diagnostics after MCMC
 #'
 #' @name fit-method-sampler_diagnostics
diff --git a/man/fit-method-bridge_sampler.Rd b/man/fit-method-bridge_sampler.Rd
new file mode 100644
index 00000000..068414ac
--- /dev/null
+++ b/man/fit-method-bridge_sampler.Rd
@@ -0,0 +1,62 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/fit.R
+\name{fit-method-bridge_sampler} +\alias{fit-method-bridge_sampler} +\alias{bridge_sampler} +\title{Marginal Log-Likelihood Approximation via Bridge Sampling} +\usage{ +bridge_sampler( + method = "normal", + repetitions = 1, + cores = 1, + use_neff = TRUE, + maxiter = 1000, + silent = FALSE, + verbose = FALSE, + ... +) +} +\arguments{ +\item{method}{(character) The method to use for bridge sampling. Options are +\code{"normal"} (default) or \code{"warp3"}.} + +\item{repetitions}{(integer) The number of repetitions for bridge sampling.} + +\item{cores}{(integer) The number of cores to be used by the \pkg{bridgesampling} +package. Defaults to \code{1}. See the \pkg{bridgesampling} package documentation +for more details.} + +\item{use_neff}{(logical) Whether to use the effective sample size (ESS) in +the optimal bridge function. Default is TRUE. If FALSE, the number of samples +is used instead.} + +\item{maxiter}{(integer) The maximum number of iterations for bridge sampling.} + +\item{silent}{(logical) Whether to suppress output from the bridge sampling +algorithm. Defaults to \code{FALSE}.} + +\item{verbose}{(logical) Whether to print verbose output. Defaults to \code{FALSE}.} + +\item{...}{Other arguments passed to the bridge sampling function.} +} +\value{ +The object returned by the bridge sampling function. +} +\description{ +The \verb{$bridge_sampler()} method computes the marginal likelihood +approximation using bridge sampling. This method requires the +\pkg{bridgesampling} package. +} +\examples{ + +\dontrun{ +fit <- cmdstanr_example("logistic") +bridge_result <- fit$bridge_sampler() +print(bridge_result) +} + +} +\seealso{ +The \pkg{bridgesampling} package website with +\href{https://cran.r-project.org/package=bridgesampling}{documentation}. 
+}
diff --git a/tests/testthat/test-fit-bridge_sampler.R b/tests/testthat/test-fit-bridge_sampler.R
new file mode 100644
index 00000000..8c1de7ad
--- /dev/null
+++ b/tests/testthat/test-fit-bridge_sampler.R
@@ -0,0 +1,97 @@
+context("bridge_sampler")
+
+set_cmdstan_path()
+
+# Fixed seed so the simulated data is reproducible across runs.
+set.seed(123)
+
+mu <- 0
+tau2 <- 0.5
+sigma2 <- 1
+
+n <- 20
+theta <- rnorm(n, mu, sqrt(tau2))
+y <- rnorm(n, theta, sqrt(sigma2))
+
+mu0 <- 0
+tau20 <- 1
+alpha <- 1
+beta <- 1
+
+dataH0 <- list(y = y, n = n, alpha = alpha, beta = beta, sigma2 = sigma2)
+dataH1 <- list(y = y, n = n, mu0 = mu0, tau20 = tau20, alpha = alpha,
+               beta = beta, sigma2 = sigma2)
+
+# tau2 is a variance: it is declared positive so that sqrt(tau2) and the
+# inverse-gamma prior are only ever evaluated inside their support.
+stancodeH0 <- 'data {
+  int n; // number of observations
+  vector[n] y; // observations
+  real alpha;
+  real beta;
+  real sigma2;
+}
+parameters {
+  real<lower=0> tau2; // group-level variance
+  vector[n] theta; // participant effects
+}
+model {
+  target += inv_gamma_lpdf(tau2 | alpha, beta);
+  target += normal_lpdf(theta | 0, sqrt(tau2));
+  target += normal_lpdf(y | theta, sqrt(sigma2));
+}
+'
+stancodeH1 <- 'data {
+  int n; // number of observations
+  vector[n] y; // observations
+  real mu0;
+  real tau20;
+  real alpha;
+  real beta;
+  real sigma2;
+}
+parameters {
+  real mu;
+  real<lower=0> tau2; // group-level variance
+  vector[n] theta; // participant effects
+}
+model {
+  target += normal_lpdf(mu | mu0, sqrt(tau20));
+  target += inv_gamma_lpdf(tau2 | alpha, beta);
+  target += normal_lpdf(theta | mu, sqrt(tau2));
+  target += normal_lpdf(y | theta, sqrt(sigma2));
+}
+'
+modH0 <- cmdstan_model(write_stan_file(stancodeH0),
+                       force_recompile = TRUE)
+modH1 <- cmdstan_model(write_stan_file(stancodeH1),
+                       force_recompile = TRUE)
+
+# Seeded, moderately sized runs keep the test reproducible and its run time low.
+fitH0 <- modH0$sample(data = dataH0, seed = 123, iter_warmup = 1000,
+                      iter_sampling = 5000, chains = 3, parallel_chains = 3)
+fitH1 <- modH1$sample(data = dataH1, seed = 123, iter_warmup = 1000,
+                      iter_sampling = 5000, chains = 3, parallel_chains = 3)
+
+test_that("bridge_sampler method can be called", {
+  skip_if_not_installed("bridgesampling")
+
+  expect_no_error({bridgeH0 <- fitH0$bridge_sampler()})
+  expect_no_error({bridgeH1 <- fitH1$bridge_sampler()})
+
+  expect_s3_class(bridgeH0, "bridge")
+  expect_s3_class(bridgeH1, "bridge")
+})
+
+test_that("bridge_sampler returns usable with bf and logml methods", {
+  skip_if_not_installed("bridgesampling")
+
+  # Objects assigned inside another test_that() block are not visible here,
+  # so the bridge objects are computed within this test.
+  bridgeH0 <- fitH0$bridge_sampler()
+  bridgeH1 <- fitH1$bridge_sampler()
+
+  expect_no_error({ bf_diff <- bridgesampling::bf(bridgeH0, bridgeH1) })
+  expect_no_error({ logml_H0 <- bridgesampling::logml(bridgeH0) })
+  expect_no_error({ logml_H1 <- bridgesampling::logml(bridgeH1) })
+})