diff --git a/NEWS.md b/NEWS.md index 7b9af2b6..a6909143 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,10 +2,11 @@ ### Enhancements +* Add `pareto_smooth` option to `weight_draws`, to Pareto smooth + weights before adding to a draws object. * Matrix multiplication of `rvar`s can now be done with the base matrix multiplication operator (`%*%`) instead of `%**%` in R >= 4.3. - # posterior 1.5.0 ### Enhancements diff --git a/R/convergence.R b/R/convergence.R index fa531193..cf895f07 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -21,6 +21,8 @@ #' | [mcse_mean()] | Monte Carlo standard error for the mean | #' | [mcse_quantile()] | Monte Carlo standard error for quantiles | #' | [mcse_sd()] | Monte Carlo standard error for standard deviations | +#' | [pareto_khat()] | Pareto khat diagnostic for tail(s) | +#' | [pareto_diags()] | Additional diagnostics related to Pareto khat | #' | [rhat_basic()] | Basic version of Rhat | #' | [rhat()] | Improved, rank-based version of Rhat | #' | [rhat_nested()] | Rhat for use with many short chains | diff --git a/R/pareto_smooth.R b/R/pareto_smooth.R index 228aede9..52ec6bba 100644 --- a/R/pareto_smooth.R +++ b/R/pareto_smooth.R @@ -5,10 +5,14 @@ #' the number of fractional moments that is useful for convergence #' diagnostics. For further details see Vehtari et al. (2022). #' +#' @family diagnostics #' @template args-pareto #' @template args-methods-dots #' @template ref-vehtari-paretosmooth-2022 #' @return `khat` estimated Generalized Pareto Distribution shape parameter k +#' +#' @seealso [`pareto_diags`] for additional related diagnostics, and +#' [`pareto_smooth`] for Pareto smoothed draws. #' @examples #' mu <- extract_variable_matrix(example_draws(), "mu") #' pareto_khat(mu) @@ -25,6 +29,7 @@ pareto_khat.default <- function(x, r_eff = NULL, ndraws_tail = NULL, verbose = FALSE, + are_log_weights = FALSE, ...) { smoothed <- pareto_smooth.default( x, @@ -34,6 +39,7 @@ pareto_khat.default <- function(x, verbose = verbose, return_k = TRUE, smooth_draws = FALSE, + are_log_weights = are_log_weights, ...) return(smoothed$diagnostics) } @@ -65,6 +71,7 @@ pareto_khat.rvar <- function(x, ...) { #' replacing tail draws by order statistics of a generalized Pareto #' distribution fit to the tail(s). #' +#' @family diagnostics #' @template args-pareto #' @template args-methods-dots #' @template ref-vehtari-paretosmooth-2022 @@ -100,6 +107,8 @@ pareto_khat.rvar <- function(x, ...) { #' when the sample size is increased, compared to the central limit #' theorem convergence rate. See Appendix B in Vehtari et al. (2022). #' +#' @seealso [`pareto_khat`] for only calculating khat, and +#' [`pareto_smooth`] for Pareto smoothed draws. #' @examples #' mu <- extract_variable_matrix(example_draws(), "mu") #' pareto_diags(mu) @@ -113,11 +122,12 @@ pareto_diags <- function(x, ...) UseMethod("pareto_diags") #' @rdname pareto_diags #' @export pareto_diags.default <- function(x, - tail = c("both", "right", "left"), - r_eff = NULL, - ndraws_tail = NULL, - verbose = FALSE, - ...) { + tail = c("both", "right", "left"), + r_eff = NULL, + ndraws_tail = NULL, + verbose = FALSE, + are_log_weights = FALSE, + ...) { smoothed <- pareto_smooth.default( x, @@ -128,6 +138,7 @@ pareto_diags.default <- function(x, extra_diags = TRUE, verbose = verbose, smooth_draws = FALSE, + are_log_weights = FALSE, ...) return(smoothed$diagnostics) @@ -189,6 +200,8 @@ pareto_diags.rvar <- function(x, ...) { #' Pareto smoothed estimates #' * `convergence_rate`: Relative convergence rate for Pareto smoothed estimates #' +#' @seealso [`pareto_khat`] for only calculating khat, and +#' [`pareto_diags`] for additional diagnostics. #' @examples #' mu <- extract_variable_matrix(example_draws(), "mu") #' pareto_smooth(mu) @@ -225,8 +238,8 @@ pareto_smooth.rvar <- function(x, return_k = TRUE, extra_diags = FALSE, ...) { ) } out <- list( - x = rvar(apply(draws_diags, margins, function(x) x[[1]]$x), nchains = nchains(x)), - diagnostics = diags + x = rvar(apply(draws_diags, margins, function(x) x[[1]]$x), nchains = nchains(x)), + diagnostics = diags ) } else { out <- rvar(apply(draws_diags, margins, function(x) x[[1]]), nchains = nchains(x)) @@ -238,25 +251,36 @@ pareto_smooth.rvar <- function(x, return_k = TRUE, extra_diags = FALSE, ...) { #' @export pareto_smooth.default <- function(x, tail = c("both", "right", "left"), - r_eff = NULL, + r_eff = 1, ndraws_tail = NULL, return_k = TRUE, extra_diags = FALSE, verbose = FALSE, + are_log_weights = FALSE, ...) { - checkmate::assert_number(ndraws_tail, null.ok = TRUE) - checkmate::assert_number(r_eff, null.ok = TRUE) - checkmate::assert_logical(extra_diags) - checkmate::assert_logical(return_k) - checkmate::assert_logical(verbose) + checkmate::expect_numeric(ndraws_tail, null.ok = TRUE) + checkmate::expect_numeric(r_eff, null.ok = TRUE) + extra_diags <- as_one_logical(extra_diags) + return_k <- as_one_logical(return_k) + verbose <- as_one_logical(verbose) + are_log_weights <- as_one_logical(are_log_weights) # check for infinite or na values if (should_return_NA(x)) { - warning_no_call("Input contains infinite or NA values, Pareto smoothing not performed.") - return(list(x = x, diagnostics = NA_real_)) + warning_no_call("Input contains infinite or NA values, or is constant. Fitting of generalized Pareto distribution not performed.") + if (!return_k) { + out <- x + } else { + out <- list(x = x, diagnostics = NA_real_) + } + return(out) } + if (are_log_weights) { + tail <- "right" + } + tail <- match.arg(tail) S <- length(x) @@ -290,6 +314,7 @@ pareto_smooth.default <- function(x, x, ndraws_tail = ndraws_tail, tail = "left", + are_log_weights = are_log_weights, ... ) left_k <- smoothed$k @@ -299,12 +324,14 @@ pareto_smooth.default <- function(x, x = smoothed$x, ndraws_tail = ndraws_tail, tail = "right", + are_log_weights = are_log_weights, ... ) right_k <- smoothed$k k <- max(left_k, right_k) x <- smoothed$x + } else { smoothed <- .pareto_smooth_tail( @@ -326,10 +353,11 @@ pareto_smooth.default <- function(x, if (verbose) { if (!extra_diags) { - diags_list <- .pareto_smooth_extra_diags(diags_list$khat, length(x)) + diags_list <- c(diags_list, .pareto_smooth_extra_diags(diags_list$khat, length(x))) } pareto_k_diagmsg( - diags = diags_list + diags = diags_list, + are_weights = are_log_weights ) } @@ -349,26 +377,32 @@ pareto_smooth.default <- function(x, ndraws_tail, smooth_draws = TRUE, tail = c("right", "left"), + are_log_weights = FALSE, ... ) { + if (are_log_weights) { + # shift log values for safe exponentiation + x <- x - max(x) + } + tail <- match.arg(tail) S <- length(x) tail_ids <- seq(S - ndraws_tail + 1, S) - if (tail == "left") { x <- -x } ord <- sort.int(x, index.return = TRUE) draws_tail <- ord$x[tail_ids] - cutoff <- ord$x[min(tail_ids) - 1] # largest value smaller than tail values + cutoff <- ord$x[min(tail_ids) - 1] # largest value smaller than tail values + max_tail <- max(draws_tail) min_tail <- min(draws_tail) - + if (ndraws_tail >= 5) { ord <- sort.int(x, index.return = TRUE) if (abs(max_tail - min_tail) < .Machine$double.eps / 100) { @@ -380,12 +414,19 @@ pareto_smooth.default <- function(x, k <- NA } else { # save time not sorting since x already sorted - fit <- gpdfit(draws_tail - cutoff, sort_x = FALSE) + if (are_log_weights) { + draws_tail <- exp(draws_tail) + cutoff <- exp(cutoff) + } + fit <- gpdfit(draws_tail - cutoff, sort_x = FALSE, ...) k <- fit$k sigma <- fit$sigma if (is.finite(k) && smooth_draws) { p <- (seq_len(ndraws_tail) - 0.5) / ndraws_tail smoothed <- qgeneralized_pareto(p = p, mu = cutoff, k = k, sigma = sigma) + if (are_log_weights) { + smoothed <- log(smoothed) + } } else { smoothed <- NULL } @@ -445,11 +486,11 @@ pareto_smooth.default <- function(x, #' @noRd ps_min_ss <- function(k, ...) { if (k < 1) { - out <- 10^(1 / (1 - max(0, k))) + out <- 10^(1 / (1 - max(0, k))) } else { - out <- Inf + out <- Inf } - out + out } @@ -506,27 +547,38 @@ ps_tail_length <- function(S, r_eff, ...) { #' #' Given S and scalar and k, form a diagnostic message string #' @param diags (numeric) named vector of diagnostic values +#' @param are_weights (logical) are the diagnostics for weights #' @param ... unused #' @return diagnostic message #' @noRd -pareto_k_diagmsg <- function(diags, ...) { +pareto_k_diagmsg <- function(diags, are_weights = FALSE, ...) { khat <- diags$khat min_ss <- diags$min_ss khat_threshold <- diags$khat_threshold convergence_rate <- diags$convergence_rate msg <- NULL - if (khat > 1) { - msg <- paste0(msg,'All estimates are unreliable. If the distribution of ratios is bounded,\n', - 'further draws may improve the estimates, but it is not possible to predict\n', - 'whether any feasible sample size is sufficient.') - } else { - if (khat > khat_threshold) { - msg <- paste0(msg, 'S is too small, and sample size larger than ', round(min_ss, 0), ' is needed for reliable results.\n') + + if (!are_weights) { + + if (khat > 1) { + msg <- paste0(msg, "All estimates are unreliable. If the distribution of draws is bounded,\n", + "further draws may improve the estimates, but it is not possible to predict\n", + "whether any feasible sample size is sufficient.") } else { - msg <- paste0(msg, 'To halve the RMSE, approximately ', round(2^(2/convergence_rate),1), ' times bigger S is needed.\n') + if (khat > khat_threshold) { + msg <- paste0(msg, "S is too small, and sample size larger than ", round(min_ss, 0), " is needed for reliable results.\n") + } else { + msg <- paste0(msg, "To halve the RMSE, approximately ", round(2^(2 / convergence_rate), 1), " times bigger S is needed.\n") + } + if (khat > 0.7) { + msg <- paste0(msg, "Bias dominates RMSE, and the variance based MCSE is underestimated.\n") + } } - if (khat > 0.7) { - msg <- paste0(msg, 'Bias dominates RMSE, and the variance based MCSE is underestimated.\n') + + } else { + + if (khat > khat_threshold || khat > 0.7) { + msg <- paste0(msg, "Pareto khat for weights is high (", round(khat, 1) ,"). This indicates a single or few weights dominate.\n", "Inference based on weighted draws will be unreliable.\n") } } message(msg) diff --git a/R/weight_draws.R b/R/weight_draws.R index 52326392..34494820 100644 --- a/R/weight_draws.R +++ b/R/weight_draws.R @@ -14,6 +14,8 @@ #' @param log (logical) Are the weights passed already on the log scale? The #' default is `FALSE`, that is, expecting `weights` to be on the standard #' (non-log) scale. +#' @param pareto_smooth (logical) Should the weights be Pareto-smoothed? +#' The default is `FALSE`. #' @template args-methods-dots #' @template return-draws #' @@ -43,6 +45,9 @@ #' head(weights(x)) #' head(weights(x, log=TRUE, normalize = FALSE)) # recover original log_wts #' +#' # add weights on log scale and Pareto smooth them +#' x <- weight_draws(x, weights = log_wts, log = TRUE, pareto_smooth = TRUE) +#' #' @export weight_draws <- function(x, weights, ...) { UseMethod("weight_draws") @@ -50,9 +55,15 @@ weight_draws <- function(x, weights, ...) { #' @rdname weight_draws #' @export -weight_draws.draws_matrix <- function(x, weights, log = FALSE, ...) { +weight_draws.draws_matrix <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) { + + + pareto_smooth <- as_one_logical(pareto_smooth) log <- as_one_logical(log) log_weights <- validate_weights(weights, x, log = log) + if (pareto_smooth) { + log_weights <- pareto_smooth_log_weights(log_weights) + } if (".log_weight" %in% variables(x, reserved = TRUE)) { # overwrite existing weights x[, ".log_weight"] <- log_weights @@ -66,9 +77,14 @@ weight_draws.draws_matrix <- function(x, weights, log = FALSE, ...) { #' @rdname weight_draws #' @export -weight_draws.draws_array <- function(x, weights, log = FALSE, ...) { +weight_draws.draws_array <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) { + + pareto_smooth <- as_one_logical(pareto_smooth) log <- as_one_logical(log) log_weights <- validate_weights(weights, x, log = log) + if (pareto_smooth) { + log_weights <- pareto_smooth_log_weights(log_weights) + } if (".log_weight" %in% variables(x, reserved = TRUE)) { # overwrite existing weights x[, , ".log_weight"] <- log_weights @@ -82,18 +98,28 @@ weight_draws.draws_array <- function(x, weights, log = FALSE, ...) { #' @rdname weight_draws #' @export -weight_draws.draws_df <- function(x, weights, log = FALSE, ...) { +weight_draws.draws_df <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) { + + pareto_smooth <- as_one_logical(pareto_smooth) log <- as_one_logical(log) log_weights <- validate_weights(weights, x, log = log) + if (pareto_smooth) { + log_weights <- pareto_smooth_log_weights(log_weights) + } x$.log_weight <- log_weights x } #' @rdname weight_draws #' @export -weight_draws.draws_list <- function(x, weights, log = FALSE, ...) { +weight_draws.draws_list <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) { + + pareto_smooth <- as_one_logical(pareto_smooth) log <- as_one_logical(log) log_weights <- validate_weights(weights, x, log = log) + if (pareto_smooth) { + log_weights <- pareto_smooth_log_weights(log_weights) + } niterations <- niterations(x) for (i in seq_len(nchains(x))) { sel <- (1 + (i - 1) * niterations):(i * niterations) @@ -104,9 +130,14 @@ weight_draws.draws_list <- function(x, weights, log = FALSE, ...) { #' @rdname weight_draws #' @export -weight_draws.draws_rvars <- function(x, weights, log = FALSE, ...) { +weight_draws.draws_rvars <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) { + + pareto_smooth <- as_one_logical(pareto_smooth) log <- as_one_logical(log) log_weights <- validate_weights(weights, x, log = log) + if (pareto_smooth) { + log_weights <- pareto_smooth_log_weights(log_weights) + } x$.log_weight <- rvar(log_weights) x } @@ -161,3 +192,14 @@ validate_weights <- function(weights, draws, log = FALSE) { } weights } + + +pareto_smooth_log_weights <- function(log_weights) { + pareto_smooth( + log_weights, + tail = "right", + return_k = TRUE, + are_log_weights = TRUE, + extra_diags = TRUE + )$x +} diff --git a/man-roxygen/args-pareto.R b/man-roxygen/args-pareto.R index a406dbde..8a4d92b4 100644 --- a/man-roxygen/args-pareto.R +++ b/man-roxygen/args-pareto.R @@ -11,10 +11,14 @@ #' @param ndraws_tail (numeric) number of draws for the tail. If #' `ndraws_tail` is not specified, it will be calculated as #' ceiling(3 * sqrt(length(x) / r_eff)) if length(x) > 225 and -#' length(x) / 5 otherwise (see Appendix H in Vehtari et al. (2022)). +#' length(x) / 5 otherwise (see Appendix H in Vehtari et +#' al. (2022)). #' @param r_eff (numeric) relative effective sample size estimate. If -#' `r_eff` is omitted, it will be calculated assuming the draws are -#' from MCMC. +#' `r_eff` is NULL, it will be calculated assuming the draws are +#' from MCMC. Default is 1. #' @param verbose (logical) Should diagnostic messages be printed? If #' `TRUE`, messages related to Pareto diagnostics will be #' printed. Default is `FALSE`. +#' @param are_log_weights (logical) Are the draws log weights? Default is +#' `FALSE`. If `TRUE` computation will take into account that the +#' draws are log weights, and only right tail will be smoothed. diff --git a/man-roxygen/ref-vehtari-paretosmooth-2022.R b/man-roxygen/ref-vehtari-paretosmooth-2022.R index 267ec29a..30f526fc 100644 --- a/man-roxygen/ref-vehtari-paretosmooth-2022.R +++ b/man-roxygen/ref-vehtari-paretosmooth-2022.R @@ -1,4 +1,4 @@ #' @references #' Aki Vehtari, Daniel Simpson, Andrew Gelman, Yuling Yao and #' Jonah Gabry (2022). Pareto Smoothed Importance Sampling. -#' arxiv:arXiv:1507.02646 +#' arxiv:arXiv:1507.02646 (version 8) diff --git a/man/diagnostics.Rd b/man/diagnostics.Rd index 43d54186..f2ba4998 100644 --- a/man/diagnostics.Rd +++ b/man/diagnostics.Rd @@ -21,6 +21,8 @@ A list of available diagnostics and links to their individual help pages. \code{\link[=mcse_mean]{mcse_mean()}} \tab Monte Carlo standard error for the mean \cr \code{\link[=mcse_quantile]{mcse_quantile()}} \tab Monte Carlo standard error for quantiles \cr \code{\link[=mcse_sd]{mcse_sd()}} \tab Monte Carlo standard error for standard deviations \cr + \code{\link[=pareto_khat]{pareto_khat()}} \tab Pareto khat diagnostic for tail(s) \cr + \code{\link[=pareto_diags]{pareto_diags()}} \tab Additional diagnostics related to Pareto khat \cr \code{\link[=rhat_basic]{rhat_basic()}} \tab Basic version of Rhat \cr \code{\link[=rhat]{rhat()}} \tab Improved, rank-based version of Rhat \cr \code{\link[=rhat_nested]{rhat_nested()}} \tab Rhat for use with many short chains \cr diff --git a/man/ess_basic.Rd b/man/ess_basic.Rd index e300ad5e..867076ca 100755 --- a/man/ess_basic.Rd +++ b/man/ess_basic.Rd @@ -79,6 +79,8 @@ Other diagnostics: \code{\link{mcse_mean}()}, \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, +\code{\link{pareto_diags}()}, +\code{\link{pareto_khat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat_nested}()}, \code{\link{rhat}()}, diff --git a/man/ess_bulk.Rd b/man/ess_bulk.Rd index adf3faf8..c1456be3 100755 --- a/man/ess_bulk.Rd +++ b/man/ess_bulk.Rd @@ -72,6 +72,8 @@ Other diagnostics: \code{\link{mcse_mean}()}, \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, +\code{\link{pareto_diags}()}, +\code{\link{pareto_khat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat_nested}()}, \code{\link{rhat}()}, diff --git a/man/ess_quantile.Rd b/man/ess_quantile.Rd index 6bfc3cdf..aa85c909 100755 --- a/man/ess_quantile.Rd +++ b/man/ess_quantile.Rd @@ -81,6 +81,8 @@ Other diagnostics: \code{\link{mcse_mean}()}, \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, +\code{\link{pareto_diags}()}, +\code{\link{pareto_khat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat_nested}()}, \code{\link{rhat}()}, diff --git a/man/ess_sd.Rd b/man/ess_sd.Rd index 2344211a..38475d2a 100755 --- a/man/ess_sd.Rd +++ b/man/ess_sd.Rd @@ -66,6 +66,8 @@ Other diagnostics: \code{\link{mcse_mean}()}, \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, +\code{\link{pareto_diags}()}, +\code{\link{pareto_khat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat_nested}()}, \code{\link{rhat}()}, diff --git a/man/ess_tail.Rd b/man/ess_tail.Rd index f211f7aa..8f959718 100755 --- a/man/ess_tail.Rd +++ b/man/ess_tail.Rd @@ -72,6 +72,8 @@ Other diagnostics: \code{\link{mcse_mean}()}, \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, +\code{\link{pareto_diags}()}, +\code{\link{pareto_khat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat_nested}()}, \code{\link{rhat}()}, diff --git a/man/mcse_mean.Rd b/man/mcse_mean.Rd index 9afaa7b7..c75935b1 100755 --- a/man/mcse_mean.Rd +++ b/man/mcse_mean.Rd @@ -63,6 +63,8 @@ Other diagnostics: \code{\link{ess_tail}()}, \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, +\code{\link{pareto_diags}()}, +\code{\link{pareto_khat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat_nested}()}, \code{\link{rhat}()}, diff --git a/man/mcse_quantile.Rd b/man/mcse_quantile.Rd index cc4f9685..2d05f626 100755 --- a/man/mcse_quantile.Rd +++ b/man/mcse_quantile.Rd @@ -78,6 +78,8 @@ Other diagnostics: \code{\link{ess_tail}()}, \code{\link{mcse_mean}()}, \code{\link{mcse_sd}()}, +\code{\link{pareto_diags}()}, +\code{\link{pareto_khat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat_nested}()}, \code{\link{rhat}()}, diff --git a/man/mcse_sd.Rd b/man/mcse_sd.Rd index 7e322864..671ef249 100755 --- a/man/mcse_sd.Rd +++ b/man/mcse_sd.Rd @@ -68,6 +68,8 @@ Other diagnostics: \code{\link{ess_tail}()}, \code{\link{mcse_mean}()}, \code{\link{mcse_quantile}()}, +\code{\link{pareto_diags}()}, +\code{\link{pareto_khat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat_nested}()}, \code{\link{rhat}()}, diff --git a/man/pareto_diags.Rd b/man/pareto_diags.Rd index c2558a9a..9a1d5776 100644 --- a/man/pareto_diags.Rd +++ b/man/pareto_diags.Rd @@ -14,6 +14,7 @@ pareto_diags(x, ...) r_eff = NULL, ndraws_tail = NULL, verbose = FALSE, + are_log_weights = FALSE, ... ) @@ -39,17 +40,22 @@ pareto_diags(x, ...) The default is \code{"both"}.} \item{r_eff}{(numeric) relative effective sample size estimate. If -\code{r_eff} is omitted, it will be calculated assuming the draws are -from MCMC.} +\code{r_eff} is NULL, it will be calculated assuming the draws are +from MCMC. Default is 1.} \item{ndraws_tail}{(numeric) number of draws for the tail. If \code{ndraws_tail} is not specified, it will be calculated as ceiling(3 * sqrt(length(x) / r_eff)) if length(x) > 225 and -length(x) / 5 otherwise (see Appendix H in Vehtari et al. (2022)).} +length(x) / 5 otherwise (see Appendix H in Vehtari et +al. (2022)).} \item{verbose}{(logical) Should diagnostic messages be printed? If \code{TRUE}, messages related to Pareto diagnostics will be printed. Default is \code{FALSE}.} + +\item{are_log_weights}{(logical) Are the draws log weights? Default is +\code{FALSE}. If \code{TRUE} computation will take into account that the +draws are log weights, and only right tail will be smoothed.} } \value{ List of Pareto smoothing diagnostics: @@ -101,5 +107,25 @@ pareto_diags(d$Sigma) \references{ Aki Vehtari, Daniel Simpson, Andrew Gelman, Yuling Yao and Jonah Gabry (2022). Pareto Smoothed Importance Sampling. -arxiv:arXiv:1507.02646 +arxiv:arXiv:1507.02646 (version 8) +} +\seealso{ +\code{\link{pareto_khat}} for only calculating khat, and +\code{\link{pareto_smooth}} for Pareto smoothed draws. + +Other diagnostics: +\code{\link{ess_basic}()}, +\code{\link{ess_bulk}()}, +\code{\link{ess_quantile}()}, +\code{\link{ess_sd}()}, +\code{\link{ess_tail}()}, +\code{\link{mcse_mean}()}, +\code{\link{mcse_quantile}()}, +\code{\link{mcse_sd}()}, +\code{\link{pareto_khat}()}, +\code{\link{rhat_basic}()}, +\code{\link{rhat_nested}()}, +\code{\link{rhat}()}, +\code{\link{rstar}()} } +\concept{diagnostics} diff --git a/man/pareto_khat.Rd b/man/pareto_khat.Rd index c3383eb4..a4f91707 100644 --- a/man/pareto_khat.Rd +++ b/man/pareto_khat.Rd @@ -14,6 +14,7 @@ pareto_khat(x, ...) r_eff = NULL, ndraws_tail = NULL, verbose = FALSE, + are_log_weights = FALSE, ... ) @@ -39,17 +40,22 @@ pareto_khat(x, ...) The default is \code{"both"}.} \item{r_eff}{(numeric) relative effective sample size estimate. If -\code{r_eff} is omitted, it will be calculated assuming the draws are -from MCMC.} +\code{r_eff} is NULL, it will be calculated assuming the draws are +from MCMC. Default is 1.} \item{ndraws_tail}{(numeric) number of draws for the tail. If \code{ndraws_tail} is not specified, it will be calculated as ceiling(3 * sqrt(length(x) / r_eff)) if length(x) > 225 and -length(x) / 5 otherwise (see Appendix H in Vehtari et al. (2022)).} +length(x) / 5 otherwise (see Appendix H in Vehtari et +al. (2022)).} \item{verbose}{(logical) Should diagnostic messages be printed? If \code{TRUE}, messages related to Pareto diagnostics will be printed. Default is \code{FALSE}.} + +\item{are_log_weights}{(logical) Are the draws log weights? Default is +\code{FALSE}. If \code{TRUE} computation will take into account that the +draws are log weights, and only right tail will be smoothed.} } \value{ \code{khat} estimated Generalized Pareto Distribution shape parameter k @@ -70,5 +76,25 @@ pareto_khat(d$Sigma) \references{ Aki Vehtari, Daniel Simpson, Andrew Gelman, Yuling Yao and Jonah Gabry (2022). Pareto Smoothed Importance Sampling. -arxiv:arXiv:1507.02646 +arxiv:arXiv:1507.02646 (version 8) +} +\seealso{ +\code{\link{pareto_diags}} for additional related diagnostics, and +\code{\link{pareto_smooth}} for Pareto smoothed draws. + +Other diagnostics: +\code{\link{ess_basic}()}, +\code{\link{ess_bulk}()}, +\code{\link{ess_quantile}()}, +\code{\link{ess_sd}()}, +\code{\link{ess_tail}()}, +\code{\link{mcse_mean}()}, +\code{\link{mcse_quantile}()}, +\code{\link{mcse_sd}()}, +\code{\link{pareto_diags}()}, +\code{\link{rhat_basic}()}, +\code{\link{rhat_nested}()}, +\code{\link{rhat}()}, +\code{\link{rstar}()} } +\concept{diagnostics} diff --git a/man/pareto_smooth.Rd b/man/pareto_smooth.Rd index 259421de..c0e6f017 100644 --- a/man/pareto_smooth.Rd +++ b/man/pareto_smooth.Rd @@ -13,11 +13,12 @@ pareto_smooth(x, ...) \method{pareto_smooth}{default}( x, tail = c("both", "right", "left"), - r_eff = NULL, + r_eff = 1, ndraws_tail = NULL, return_k = TRUE, extra_diags = FALSE, verbose = FALSE, + are_log_weights = FALSE, ... ) } @@ -50,17 +51,22 @@ returned. Default is \code{FALSE}.} The default is \code{"both"}.} \item{r_eff}{(numeric) relative effective sample size estimate. If -\code{r_eff} is omitted, it will be calculated assuming the draws are -from MCMC.} +\code{r_eff} is NULL, it will be calculated assuming the draws are +from MCMC. Default is 1.} \item{ndraws_tail}{(numeric) number of draws for the tail. If \code{ndraws_tail} is not specified, it will be calculated as ceiling(3 * sqrt(length(x) / r_eff)) if length(x) > 225 and -length(x) / 5 otherwise (see Appendix H in Vehtari et al. (2022)).} +length(x) / 5 otherwise (see Appendix H in Vehtari et +al. (2022)).} \item{verbose}{(logical) Should diagnostic messages be printed? If \code{TRUE}, messages related to Pareto diagnostics will be printed. Default is \code{FALSE}.} + +\item{are_log_weights}{(logical) Are the draws log weights? Default is +\code{FALSE}. If \code{TRUE} computation will take into account that the +draws are log weights, and only right tail will be smoothed.} } \value{ Either a vector \code{x} of smoothed values or a named list @@ -91,5 +97,9 @@ pareto_smooth(d$Sigma) \references{ Aki Vehtari, Daniel Simpson, Andrew Gelman, Yuling Yao and Jonah Gabry (2022). Pareto Smoothed Importance Sampling. -arxiv:arXiv:1507.02646 +arxiv:arXiv:1507.02646 (version 8) +} +\seealso{ +\code{\link{pareto_khat}} for only calculating khat, and +\code{\link{pareto_diags}} for additional diagnostics. } diff --git a/man/rhat.Rd b/man/rhat.Rd index fed2c14a..263561cd 100755 --- a/man/rhat.Rd +++ b/man/rhat.Rd @@ -67,6 +67,8 @@ Other diagnostics: \code{\link{mcse_mean}()}, \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, +\code{\link{pareto_diags}()}, +\code{\link{pareto_khat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat_nested}()}, \code{\link{rstar}()} diff --git a/man/rhat_basic.Rd b/man/rhat_basic.Rd index 8a94efb3..16ffd332 100755 --- a/man/rhat_basic.Rd +++ b/man/rhat_basic.Rd @@ -75,6 +75,8 @@ Other diagnostics: \code{\link{mcse_mean}()}, \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, +\code{\link{pareto_diags}()}, +\code{\link{pareto_khat}()}, \code{\link{rhat_nested}()}, \code{\link{rhat}()}, \code{\link{rstar}()} diff --git a/man/rhat_nested.Rd b/man/rhat_nested.Rd index 2e23242d..f2536efd 100644 --- a/man/rhat_nested.Rd +++ b/man/rhat_nested.Rd @@ -83,6 +83,8 @@ Other diagnostics: \code{\link{mcse_mean}()}, \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, +\code{\link{pareto_diags}()}, +\code{\link{pareto_khat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat}()}, \code{\link{rstar}()} diff --git a/man/rstar.Rd b/man/rstar.Rd index 87e8e372..c9479902 100644 --- a/man/rstar.Rd +++ b/man/rstar.Rd @@ -115,6 +115,8 @@ Other diagnostics: \code{\link{mcse_mean}()}, \code{\link{mcse_quantile}()}, \code{\link{mcse_sd}()}, +\code{\link{pareto_diags}()}, +\code{\link{pareto_khat}()}, \code{\link{rhat_basic}()}, \code{\link{rhat_nested}()}, \code{\link{rhat}()} diff --git a/man/weight_draws.Rd b/man/weight_draws.Rd index 4601c983..d866d466 100644 --- a/man/weight_draws.Rd +++ b/man/weight_draws.Rd @@ -11,15 +11,15 @@ \usage{ weight_draws(x, weights, ...) -\method{weight_draws}{draws_matrix}(x, weights, log = FALSE, ...) +\method{weight_draws}{draws_matrix}(x, weights, log = FALSE, pareto_smooth = FALSE, ...) -\method{weight_draws}{draws_array}(x, weights, log = FALSE, ...) +\method{weight_draws}{draws_array}(x, weights, log = FALSE, pareto_smooth = FALSE, ...) -\method{weight_draws}{draws_df}(x, weights, log = FALSE, ...) +\method{weight_draws}{draws_df}(x, weights, log = FALSE, pareto_smooth = FALSE, ...) -\method{weight_draws}{draws_list}(x, weights, log = FALSE, ...) +\method{weight_draws}{draws_list}(x, weights, log = FALSE, pareto_smooth = FALSE, ...) -\method{weight_draws}{draws_rvars}(x, weights, log = FALSE, ...) +\method{weight_draws}{draws_rvars}(x, weights, log = FALSE, pareto_smooth = FALSE, ...) } \arguments{ \item{x}{(draws) A \code{draws} object or another \R object for which the method @@ -35,6 +35,9 @@ can be returned via the \code{\link[=weights.draws]{weights.draws()}} method lat \item{log}{(logical) Are the weights passed already on the log scale? The default is \code{FALSE}, that is, expecting \code{weights} to be on the standard (non-log) scale.} + +\item{pareto_smooth}{(logical) Should the weights be Pareto-smoothed? +The default is \code{FALSE}.} } \value{ A \code{draws} object of the same class as \code{x}. @@ -70,6 +73,9 @@ x <- weight_draws(x, weights = log_wts, log = TRUE) head(weights(x)) head(weights(x, log=TRUE, normalize = FALSE)) # recover original log_wts +# add weights on log scale and Pareto smooth them +x <- weight_draws(x, weights = log_wts, log = TRUE, pareto_smooth = TRUE) + } \seealso{ \code{\link[=weights.draws]{weights.draws()}}, \code{\link[=resample_draws]{resample_draws()}} diff --git a/man/weights.draws.Rd b/man/weights.draws.Rd index 6b2a46ff..1a47788e 100644 --- a/man/weights.draws.Rd +++ b/man/weights.draws.Rd @@ -48,6 +48,9 @@ x <- weight_draws(x, weights = log_wts, log = TRUE) head(weights(x)) head(weights(x, log=TRUE, normalize = FALSE)) # recover original log_wts +# add weights on log scale and Pareto smooth them +x <- weight_draws(x, weights = log_wts, log = TRUE, pareto_smooth = TRUE) + } \seealso{ \code{\link{weight_draws}}, \code{\link{resample_draws}} diff --git a/tests/testthat/test-pareto_smooth.R b/tests/testthat/test-pareto_smooth.R index dff22ee1..6b67d2b0 100644 --- a/tests/testthat/test-pareto_smooth.R +++ b/tests/testthat/test-pareto_smooth.R @@ -73,7 +73,7 @@ test_that("pareto_khat diagnostics messages are as expected", { diags$khat <- 1.1 expect_message(pareto_k_diagmsg(diags), - paste0('All estimates are unreliable. If the distribution of ratios is bounded,\n', + paste0('All estimates are unreliable. If the distribution of draws is bounded,\n', 'further draws may improve the estimates, but it is not possible to predict\n', 'whether any feasible sample size is sufficient.')) @@ -192,3 +192,16 @@ test_that("pareto_smooth returns x with smoothed tail", { expect_false(isTRUE(all.equal(sort(tau), sort(tau_smoothed)))) }) + +test_that("pareto_smooth works for log_weights", { + w <- c(1:25, 1e3, 1e3, 1e3) + lw <- log(w) + + ps <- pareto_smooth(lw, are_log_weights = TRUE, verbose = FALSE, ndraws_tail = 10) + + # only right tail is smoothed + expect_equal(ps$x[1:15], lw[1:15]) + + expect_true(ps$diagnostics$khat > 0.7) + +}) diff --git a/tests/testthat/test-weight_draws.R b/tests/testthat/test-weight_draws.R index 87337439..fb6e6cc3 100644 --- a/tests/testthat/test-weight_draws.R +++ b/tests/testthat/test-weight_draws.R @@ -63,7 +63,6 @@ test_that("weight_draws works on draws_rvars", { expect_equal(weights2, weights) }) - # conversion preserves weights -------------------------------------------- test_that("conversion between formats preserves weights", { @@ -88,3 +87,13 @@ test_that("conversion between formats preserves weights", { expect_equal(as_draws_rvars(draws[[!!type]]), draws$rvars) } }) + +# pareto smoothing ---------------- + +test_that("pareto smoothing smooths weights in weight_draws", { + x <- example_draws() + lw <- sort(log(abs(rt(ndraws(x), 1)))) + weighted <- weight_draws(x, lw, pareto_smooth = FALSE, log = TRUE) + smoothed <- weight_draws(x, lw, pareto_smooth = TRUE, log = TRUE) + expect_false(all(weights(weighted) == weights(smoothed))) +})