Pareto-smoothing updates #314

Merged
merged 17 commits into from
Nov 23, 2023
Changes from 6 commits
2 changes: 2 additions & 0 deletions R/convergence.R
@@ -21,6 +21,8 @@
#' | [mcse_mean()] | Monte Carlo standard error for the mean |
#' | [mcse_quantile()] | Monte Carlo standard error for quantiles |
#' | [mcse_sd()] | Monte Carlo standard error for standard deviations |
#' | [pareto_khat()] | Pareto khat diagnostic for tail(s) |
#' | [pareto_diags()] | Additional diagnostics related to Pareto khat |
#' | [rhat_basic()] | Basic version of Rhat |
#' | [rhat()] | Improved, rank-based version of Rhat |
#' | [rhat_nested()] | Rhat for use with many short chains |
62 changes: 48 additions & 14 deletions R/pareto_smooth.R
@@ -5,10 +5,14 @@
#' the number of fractional moments that is useful for convergence
#' diagnostics. For further details see Vehtari et al. (2022).
#'
#' @family diagnostics
#' @template args-pareto
#' @template args-methods-dots
#' @template ref-vehtari-paretosmooth-2022
#' @return `khat` estimated Generalized Pareto Distribution shape parameter k
#'
#' @seealso [`pareto_diags`] for additional related diagnostics, and
#' [`pareto_smooth`] for Pareto smoothed draws.
#' @examples
#' mu <- extract_variable_matrix(example_draws(), "mu")
#' pareto_khat(mu)
@@ -25,6 +29,7 @@ pareto_khat.default <- function(x,
r_eff = NULL,
ndraws_tail = NULL,
verbose = FALSE,
log_weights = FALSE,
...) {
smoothed <- pareto_smooth.default(
x,
@@ -34,6 +39,7 @@ pareto_khat.default <- function(x,
verbose = verbose,
return_k = TRUE,
smooth_draws = FALSE,
log_weights = log_weights,
...)
return(smoothed$diagnostics)
}
@@ -65,6 +71,7 @@ pareto_khat.rvar <- function(x, ...) {
#' replacing tail draws by order statistics of a generalized Pareto
#' distribution fit to the tail(s).
#'
#' @family diagnostics
#' @template args-pareto
#' @template args-methods-dots
#' @template ref-vehtari-paretosmooth-2022
@@ -100,6 +107,8 @@ pareto_khat.rvar <- function(x, ...) {
#' when the sample size is increased, compared to the central limit
#' theorem convergence rate. See Appendix B in Vehtari et al. (2022).
#'
#' @seealso [`pareto_khat`] for only calculating khat, and
#' [`pareto_smooth`] for Pareto smoothed draws.
#' @examples
#' mu <- extract_variable_matrix(example_draws(), "mu")
#' pareto_diags(mu)
@@ -113,11 +122,12 @@ pareto_diags <- function(x, ...) UseMethod("pareto_diags")
#' @rdname pareto_diags
#' @export
pareto_diags.default <- function(x,
tail = c("both", "right", "left"),
r_eff = NULL,
ndraws_tail = NULL,
verbose = FALSE,
log_weights = FALSE,
...) {

smoothed <- pareto_smooth.default(
x,
@@ -128,6 +138,7 @@ pareto_diags.default <- function(x,
extra_diags = TRUE,
verbose = verbose,
smooth_draws = FALSE,
log_weights = log_weights,
...)

return(smoothed$diagnostics)
@@ -189,6 +200,8 @@ pareto_diags.rvar <- function(x, ...) {
#' Pareto smoothed estimates
#' * `convergence_rate`: Relative convergence rate for Pareto smoothed estimates
#'
#' @seealso [`pareto_khat`] for only calculating khat, and
#' [`pareto_diags`] for additional diagnostics.
#' @examples
#' mu <- extract_variable_matrix(example_draws(), "mu")
#' pareto_smooth(mu)
@@ -225,8 +238,8 @@ pareto_smooth.rvar <- function(x, return_k = TRUE, extra_diags = FALSE, ...) {
)
}
out <- list(
x = rvar(apply(draws_diags, margins, function(x) x[[1]]$x), nchains = nchains(x)),
diagnostics = diags
)
} else {
out <- rvar(apply(draws_diags, margins, function(x) x[[1]]), nchains = nchains(x))
@@ -243,20 +256,26 @@ pareto_smooth.default <- function(x,
return_k = TRUE,
extra_diags = FALSE,
verbose = FALSE,
log_weights = FALSE,
...) {

checkmate::assert_number(ndraws_tail, null.ok = TRUE)
checkmate::assert_number(r_eff, null.ok = TRUE)
checkmate::assert_logical(extra_diags)
checkmate::assert_logical(return_k)
checkmate::assert_logical(verbose)
checkmate::assert_logical(log_weights)

# check for infinite or na values
if (should_return_NA(x)) {
warning_no_call("Input contains infinite or NA values, Pareto smoothing not performed.")
return(list(x = x, diagnostics = NA_real_))
}

if (log_weights) {
tail <- "right"
}

tail <- match.arg(tail)
S <- length(x)
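
A note on why forcing `tail` before `match.arg()` works: `match.arg()` takes its set of choices from the function's formals, so a value assigned beforehand is simply validated against them. The sketch below illustrates this in isolation; `choose_tail` is a hypothetical helper written for this note and is not part of the package.

```r
# Hypothetical helper mirroring the argument handling in the hunk above.
choose_tail <- function(tail = c("both", "right", "left"), log_weights = FALSE) {
  if (log_weights) {
    # Log weights only have a meaningful right tail, so override the choice.
    tail <- "right"
  }
  # Choices come from the formals; an untouched default yields the first one.
  match.arg(tail)
}

res_default <- choose_tail()
res_forced <- choose_tail(tail = "left", log_weights = TRUE)
res_partial <- choose_tail("le")  # partial matching is allowed by match.arg()
print(c(res_default, res_forced, res_partial))
```

The override is silent: a caller who passes `tail = "left"` together with `log_weights = TRUE` gets the right tail without a warning.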

@@ -290,6 +309,7 @@ pareto_smooth.default <- function(x,
x,
ndraws_tail = ndraws_tail,
tail = "left",
log_weights = log_weights,
...
)
left_k <- smoothed$k
@@ -299,6 +319,7 @@ pareto_smooth.default <- function(x,
x = smoothed$x,
ndraws_tail = ndraws_tail,
tail = "right",
log_weights = log_weights,
...
)
right_k <- smoothed$k
@@ -349,26 +370,32 @@ pareto_smooth.default <- function(x,
ndraws_tail,
smooth_draws = TRUE,
tail = c("right", "left"),
log_weights = FALSE,
...
) {

if (log_weights) {
# shift log values for safe exponentiation
x <- x - max(x)
}

tail <- match.arg(tail)

S <- length(x)
tail_ids <- seq(S - ndraws_tail + 1, S)


if (tail == "left") {
x <- -x
}

ord <- sort.int(x, index.return = TRUE)
draws_tail <- ord$x[tail_ids]

cutoff <- ord$x[min(tail_ids) - 1] # largest value smaller than tail values

max_tail <- max(draws_tail)
min_tail <- min(draws_tail)

if (ndraws_tail >= 5) {
ord <- sort.int(x, index.return = TRUE)
if (abs(max_tail - min_tail) < .Machine$double.eps / 100) {
@@ -380,12 +407,19 @@ pareto_smooth.default <- function(x,
k <- NA
} else {
# save time not sorting since x already sorted
-fit <- gpdfit(draws_tail - cutoff, sort_x = FALSE)
+if (log_weights) {
+draws_tail <- exp(draws_tail)
+cutoff <- exp(cutoff)
+}
+fit <- gpdfit(draws_tail - cutoff, sort_x = FALSE, ...)
k <- fit$k
sigma <- fit$sigma
if (is.finite(k) && smooth_draws) {
p <- (seq_len(ndraws_tail) - 0.5) / ndraws_tail
smoothed <- qgeneralized_pareto(p = p, mu = cutoff, k = k, sigma = sigma)
if (log_weights) {
smoothed <- log(smoothed)
}
} else {
smoothed <- NULL
}
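
The log-weights path in this hunk can be sketched end to end in base R as follows. This is a rough illustration under stated assumptions, not the package's implementation: `qgpd_sketch` is a hypothetical stand-in for the internal `qgeneralized_pareto()`, the values of `k` and `sigma` are made up rather than estimated with `gpdfit()`, and the final write-back step is assumed, since it is not visible in this diff.

```r
# Hypothetical stand-in for the package's generalized Pareto quantile function.
qgpd_sketch <- function(p, mu, sigma, k) {
  if (abs(k) < 1e-12) {
    mu - sigma * log1p(-p)                  # exponential limit as k -> 0
  } else {
    mu + sigma * expm1(-k * log1p(-p)) / k
  }
}

set.seed(1)
log_w <- rnorm(200, sd = 3)                 # made-up log weights
ndraws_tail <- 20

log_w <- log_w - max(log_w)                 # shift so exp() cannot overflow
ord <- sort.int(log_w, index.return = TRUE)
S <- length(log_w)
tail_ids <- seq(S - ndraws_tail + 1, S)
cutoff <- exp(ord$x[min(tail_ids) - 1])     # cutoff on the weight scale

# Assumed values; the package estimates these from the exponentiated tail.
k <- 0.5
sigma <- 0.05

# Replace tail draws by GPD order statistics, then return to the log scale.
p <- (seq_len(ndraws_tail) - 0.5) / ndraws_tail
smoothed_log <- log(qgpd_sketch(p, mu = cutoff, sigma = sigma, k = k))

# Assumed write-back: smoothed values go to the positions of the tail draws.
log_w_smoothed <- log_w
log_w_smoothed[ord$ix[tail_ids]] <- smoothed_log
```

The fit happens on the weight scale because the generalized Pareto model describes the tail of the weights themselves, not of their logarithms; only the inputs and outputs stay on the log scale.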
@@ -445,11 +479,11 @@ pareto_smooth.default <- function(x,
#' @noRd
ps_min_ss <- function(k, ...) {
if (k < 1) {
out <- 10^(1 / (1 - max(0, k)))
} else {
out <- Inf
}
out
}
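
For a feel of what `ps_min_ss()` returns, the snippet below evaluates a copy of the helper over a few khat values. The reading of the result as a minimum sample size is inferred from the function name and from Vehtari et al. (2022); the chosen khat values are arbitrary.

```r
# Copy of the helper from the hunk above, reproduced for illustration only.
ps_min_ss <- function(k, ...) {
  if (k < 1) {
    out <- 10^(1 / (1 - max(0, k)))
  } else {
    out <- Inf
  }
  out
}

khats <- c(-0.2, 0, 0.5, 0.7, 0.9, 1.2)   # arbitrary example values
min_ss <- vapply(khats, ps_min_ss, numeric(1))
print(data.frame(khat = khats, min_ss = min_ss))
```

Negative khat values are clamped to 0, so the result never drops below 10, and it grows rapidly as khat approaches 1, where it becomes infinite.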


27 changes: 22 additions & 5 deletions R/weight_draws.R
@@ -14,6 +14,8 @@
#' @param log (logical) Are the weights passed already on the log scale? The
#' default is `FALSE`, that is, expecting `weights` to be on the standard
#' (non-log) scale.
#' @param pareto_smooth (logical) Should the weights be Pareto-smoothed?
#' The default is `FALSE`.
#' @template args-methods-dots
#' @template return-draws
#'
@@ -50,9 +52,12 @@ weight_draws <- function(x, weights, ...) {

#' @rdname weight_draws
#' @export
-weight_draws.draws_matrix <- function(x, weights, log = FALSE, ...) {
+weight_draws.draws_matrix <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) {
log <- as_one_logical(log)
log_weights <- validate_weights(weights, x, log = log)
if (pareto_smooth) {
log_weights <- pareto_smooth(log_weights, tail = "right", return_k = FALSE, log_weights = TRUE)
}
if (".log_weight" %in% variables(x, reserved = TRUE)) {
# overwrite existing weights
x[, ".log_weight"] <- log_weights
@@ -66,9 +71,12 @@ weight_draws.draws_matrix <- function(x, weights, log = FALSE, ...) {

#' @rdname weight_draws
#' @export
-weight_draws.draws_array <- function(x, weights, log = FALSE, ...) {
+weight_draws.draws_array <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) {
log <- as_one_logical(log)
log_weights <- validate_weights(weights, x, log = log)
if (pareto_smooth) {
log_weights <- pareto_smooth(log_weights, tail = "right", return_k = FALSE, log_weights = TRUE)
}
if (".log_weight" %in% variables(x, reserved = TRUE)) {
# overwrite existing weights
x[, , ".log_weight"] <- log_weights
@@ -82,18 +90,24 @@ weight_draws.draws_array <- function(x, weights, log = FALSE, ...) {

#' @rdname weight_draws
#' @export
-weight_draws.draws_df <- function(x, weights, log = FALSE, ...) {
+weight_draws.draws_df <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) {
log <- as_one_logical(log)
log_weights <- validate_weights(weights, x, log = log)
if (pareto_smooth) {
log_weights <- pareto_smooth(log_weights, tail = "right", return_k = FALSE, log_weights = TRUE)
}
x$.log_weight <- log_weights
x
}

#' @rdname weight_draws
#' @export
-weight_draws.draws_list <- function(x, weights, log = FALSE, ...) {
+weight_draws.draws_list <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) {
log <- as_one_logical(log)
log_weights <- validate_weights(weights, x, log = log)
if (pareto_smooth) {
log_weights <- pareto_smooth(log_weights, tail = "right", return_k = FALSE, log_weights = TRUE)
}
niterations <- niterations(x)
for (i in seq_len(nchains(x))) {
sel <- (1 + (i - 1) * niterations):(i * niterations)
@@ -104,9 +118,12 @@ weight_draws.draws_list <- function(x, weights, log = FALSE, ...) {

#' @rdname weight_draws
#' @export
-weight_draws.draws_rvars <- function(x, weights, log = FALSE, ...) {
+weight_draws.draws_rvars <- function(x, weights, log = FALSE, pareto_smooth = FALSE, ...) {
log <- as_one_logical(log)
log_weights <- validate_weights(weights, x, log = log)
if (pareto_smooth) {
log_weights <- pareto_smooth(log_weights, tail = "right", return_k = FALSE, log_weights = TRUE)
}
x$.log_weight <- rvar(log_weights)
x
}
6 changes: 5 additions & 1 deletion man-roxygen/args-pareto.R
@@ -11,10 +11,14 @@
#' @param ndraws_tail (numeric) number of draws for the tail. If
#' `ndraws_tail` is not specified, it will be calculated as
#' ceiling(3 * sqrt(length(x) / r_eff)) if length(x) > 225 and
-#' length(x) / 5 otherwise (see Appendix H in Vehtari et al. (2022)).
+#' length(x) / 5 otherwise (see Appendix H in Vehtari et
+#' al. (2022)).
#' @param r_eff (numeric) relative effective sample size estimate. If
#' `r_eff` is omitted, it will be calculated assuming the draws are
#' from MCMC.
#' @param verbose (logical) Should diagnostic messages be printed? If
#' `TRUE`, messages related to Pareto diagnostics will be
#' printed. Default is `FALSE`.
#' @param log_weights (logical) Are the draws log weights? Default is
#' `FALSE`. If `TRUE`, the computation takes into account that the
#' draws are log weights, and only the right tail is smoothed.
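
The default rule for `ndraws_tail` documented above can be written out as a short function. `default_ndraws_tail` is a hypothetical name used only for this illustration, and `r_eff = 1` is assumed as a stand-in for the value the package would otherwise estimate.

```r
# Hypothetical helper; the rule is taken from the ndraws_tail documentation.
default_ndraws_tail <- function(n, r_eff = 1) {
  if (n > 225) {
    ceiling(3 * sqrt(n / r_eff))
  } else {
    n / 5
  }
}

n_large <- default_ndraws_tail(4000)               # 190
n_autocorr <- default_ndraws_tail(4000, r_eff = 0.5)  # 269
n_small <- default_ndraws_tail(100)                # 20
print(c(n_large, n_autocorr, n_small))
```

A lower `r_eff` (stronger autocorrelation) enlarges the tail, and the two branches meet at `length(x) = 225`, where both give 45.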
2 changes: 1 addition & 1 deletion man-roxygen/ref-vehtari-paretosmooth-2022.R
@@ -1,4 +1,4 @@
#' @references
#' Aki Vehtari, Daniel Simpson, Andrew Gelman, Yuling Yao and
#' Jonah Gabry (2022). Pareto Smoothed Importance Sampling.
-#' arxiv:arXiv:1507.02646
+#' arxiv:arXiv:1507.02646 (version 8)
2 changes: 2 additions & 0 deletions man/diagnostics.Rd
2 changes: 2 additions & 0 deletions man/ess_basic.Rd
2 changes: 2 additions & 0 deletions man/ess_bulk.Rd
2 changes: 2 additions & 0 deletions man/ess_quantile.Rd
2 changes: 2 additions & 0 deletions man/ess_sd.Rd
2 changes: 2 additions & 0 deletions man/ess_tail.Rd
2 changes: 2 additions & 0 deletions man/mcse_mean.Rd
2 changes: 2 additions & 0 deletions man/mcse_quantile.Rd
2 changes: 2 additions & 0 deletions man/mcse_sd.Rd