Skip to content

Commit

Permalink
Merge branch 'master' of github.com:n-kall/posterior
Browse files Browse the repository at this point in the history
  • Loading branch information
n-kall committed Jan 9, 2024
2 parents 12b0077 + 627d2e8 commit 4f4613c
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 60 deletions.
9 changes: 9 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,16 @@ S3method(order_draws,draws_list)
S3method(order_draws,draws_matrix)
S3method(order_draws,draws_rvars)
S3method(order_draws,rvar)
S3method(pareto_convergence_rate,default)
S3method(pareto_convergence_rate,rvar)
S3method(pareto_diags,default)
S3method(pareto_diags,rvar)
S3method(pareto_khat,default)
S3method(pareto_khat,rvar)
S3method(pareto_khat_threshold,default)
S3method(pareto_khat_threshold,rvar)
S3method(pareto_min_ss,default)
S3method(pareto_min_ss,rvar)
S3method(pareto_smooth,default)
S3method(pareto_smooth,rvar)
S3method(pillar_shaft,rvar)
Expand Down Expand Up @@ -455,8 +461,11 @@ export(ndraws)
export(niterations)
export(nvariables)
export(order_draws)
export(pareto_convergence_rate)
export(pareto_diags)
export(pareto_khat)
export(pareto_khat_threshold)
export(pareto_min_ss)
export(pareto_smooth)
export(quantile2)
export(r_scale)
Expand Down
59 changes: 31 additions & 28 deletions R/pareto_smooth.R
Original file line number Diff line number Diff line change
Expand Up @@ -181,24 +181,23 @@ pareto_diags.rvar <- function(x, ...) {
#'
#' @template args-pareto
#' @param return_k (logical) Should the Pareto khat be included in
#' output? If `TRUE`, output will be a list containing of smoothed
#' draws and diagnostics. Default is `TRUE`.
#' output? If `TRUE`, output will be a list containing smoothed
#' draws and diagnostics, otherwise it will be a numeric of the
#' smoothed draws. Default is `FALSE`.
#' @param extra_diags (logical) Should extra Pareto khat diagnostics
#' be included in output? If `TRUE`, `min_ss`, `khat_threshold` and
#' `convergence_rate` for the estimated k value will be
#' returned. Default is `FALSE`.
#' @template args-methods-dots
#' @template ref-vehtari-paretosmooth-2022
#' @return Either a vector `x` of smoothed values or a named list
#' containing the vector `x` and a named list `diagnostics` containing Pareto smoothing
#' diagnostics:
#' * `khat`: estimated Pareto k shape parameter, and
#' optionally
#' * `min_ss`: minimum sample size for reliable Pareto
#' smoothed estimate
#' * `khat_threshold`: khat-threshold for reliable
#' containing the vector `x` and a named list `diagnostics`
#' containing Pareto smoothing diagnostics: * `khat`: estimated
#' Pareto k shape parameter, and optionally * `min_ss`: minimum
#' sample size for reliable Pareto smoothed estimate *
#' `khat_threshold`: khat-threshold for reliable Pareto smoothed
#' estimates * `convergence_rate`: Relative convergence rate for
#' Pareto smoothed estimates
#' * `convergence_rate`: Relative convergence rate for Pareto smoothed estimates
#'
#' @seealso [`pareto_khat`] for only calculating khat, and
#' [`pareto_diags`] for additional diagnostics.
Expand Down Expand Up @@ -259,8 +258,8 @@ pareto_smooth.default <- function(x,
are_log_weights = FALSE,
...) {

checkmate::expect_numeric(ndraws_tail, null.ok = TRUE)
checkmate::expect_numeric(r_eff, null.ok = TRUE)
checkmate::assert_numeric(ndraws_tail, null.ok = TRUE)
checkmate::assert_numeric(r_eff, null.ok = TRUE)
extra_diags <- as_one_logical(extra_diags)
return_k <- as_one_logical(return_k)
verbose <- as_one_logical(verbose)
Expand Down Expand Up @@ -370,55 +369,59 @@ pareto_smooth.default <- function(x,
return(out)
}

#' Threshold for Pareto k-hat diagnostic based on sample size
#'
#' @param x
#' @param ...
#' @return
#' @rdname pareto_diags
#' @export
pareto_khat_threshold <- function(x, ...) {
UseMethod("pareto_khat_threshold")
}


#' @rdname pareto_diags
#' @export
pareto_khat_threshold.default <- function(x, ...) {
c(khat_threshold = ps_khat_threshold(length(x)))
}

#' @rdname pareto_diags
#' @export
pareto_khat_threshold.rvar <- function(x, ...) {
c(khat_threshold = ps_khat_threshold(ndraws(x)))
}

#' Minimum sample size for Pareto diagnostics
#'
#' @param ...
#' @return
#' @rdname pareto_diags
#' @export
pareto_min_ss <- function(x, ...) {
UseMethod("pareto_min_ss")
}

#' @rdname pareto_diags
#' @export
pareto_min_ss.default <- function(x, ...) {
k <- pareto_khat(x)$k
c(min_ss = ps_min_ss(k))
}

#' @rdname pareto_diags
#' @export
pareto_min_ss.rvar <- function(x, ...) {
k <- pareto_khat(x)$k
c(min_ss = ps_min_ss(k))
}

#' Convergence rate based on Pareto diagnostics
#'
#' @param ...
#' @return
#' @rdname pareto_diags
#' @export
pareto_convergence_rate <- function(x, ...) {
UseMethod("pareto_convergence_rate")
}

#' @rdname pareto_diags
#' @export
pareto_convergence_rate.default <- function(x, ...) {
k <- pareto_khat(x)$khat
c(convergence_rate = ps_convergence_rate(k, length(x)))
}

#' @rdname pareto_diags
#' @export
pareto_convergence_rate.rvar <- function(x, ...) {
k <- pareto_khat(x)
c(convergence_rate = ps_convergence_rate(k, ndraws(x)))
Expand Down Expand Up @@ -618,7 +621,7 @@ pareto_k_diagmsg <- function(diags, are_weights = FALSE, ...) {
msg <- paste0(msg, " Mean does not exist, making empirical mean estimate of the draws not applicable.")
} else {
if (khat > khat_threshold) {
msg <- paste0(msg, "Sample size is too small, for given Pareto k-hat. Sample size larger than ", round(min_ss, 0), " is needed for reliable results.\n")
msg <- paste0(msg, " Sample size is too small, for given Pareto k-hat. Sample size larger than ", round(min_ss, 0), " is needed for reliable results.\n")
}
if (khat > 0.7) {
msg <- paste0(msg, " Bias dominates when k-hat > 0.7, making empirical mean estimate of the Pareto-smoothed draws unreliable.\n")
Expand All @@ -629,6 +632,6 @@ pareto_k_diagmsg <- function(diags, are_weights = FALSE, ...) {
msg <- paste0(msg, " Pareto khat for weights is high (", round(khat, 1) ,"). This indicates a single or few weights dominate.\n", "Inference based on weighted draws will be unreliable.\n")
}
}
message("Pareto k-hat = ", round(khat, 2), ". ", msg)
message("Pareto k-hat = ", round(khat, 2), ".", msg)
invisible(diags)
}
4 changes: 2 additions & 2 deletions R/weight_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,8 @@ weights.draws <- function(object, log = FALSE, normalize = TRUE, ...) {

# validate weights and return log weights
validate_weights <- function(weights, draws, log = FALSE) {
checkmate::expect_numeric(weights)
checkmate::expect_flag(log)
checkmate::assert_numeric(weights)
checkmate::assert_flag(log)
if (length(weights) != ndraws(draws)) {
stop_no_call("Number of weights must match the number of draws.")
}
Expand Down
27 changes: 27 additions & 0 deletions man/pareto_diags.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 9 additions & 12 deletions man/pareto_smooth.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 16 additions & 18 deletions tests/testthat/test-pareto_smooth.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,26 +56,24 @@ test_that("pareto_khat diagnostics messages are as expected", {
)

expect_message(pareto_k_diagmsg(diags),
paste0('To halve the RMSE, approximately 4.1 times bigger S is needed.'))
paste0("Pareto k-hat = 0.5.\n"))

diags$khat <- 0.6

expect_message(pareto_k_diagmsg(diags),
paste0('S is too small, and sample size larger than 10 is needed for reliable results.\n'))
paste0("Pareto k-hat = 0.6. Sample size is too small, for given Pareto k-hat. Sample size larger than 10 is needed for reliable results.\n"))

diags$khat <- 0.71
diags$khat_threshold <- 0.8

expect_message(pareto_k_diagmsg(diags),
paste0('To halve the RMSE, approximately 4.1 times bigger S is needed.\n', 'Bias dominates RMSE, and the variance based MCSE is underestimated.\n'))
paste0("Pareto k-hat = 0.71. Bias dominates when k-hat > 0.7, making empirical mean estimate of the Pareto-smoothed draws unreliable.\n"))


diags$khat <- 1.1

expect_message(pareto_k_diagmsg(diags),
paste0('All estimates are unreliable. If the distribution of draws is bounded,\n',
'further draws may improve the estimates, but it is not possible to predict\n',
'whether any feasible sample size is sufficient.'))
paste0("Pareto k-hat = 1.1. Mean does not exist, making empirical mean estimate of the draws not applicable.\n"))

})

Expand Down Expand Up @@ -131,8 +129,8 @@ test_that("pareto_khat functions work with matrix with chains", {
expect_equal(pareto_khat(tau_chains, ndraws_tail = 20),
pareto_khat(tau_nochains, ndraws_tail = 20))

ps_chains <- pareto_smooth(tau_chains, ndraws_tail = 20)
ps_nochains <- pareto_smooth(tau_nochains, ndraws_tail = 20)
ps_chains <- pareto_smooth(tau_chains, ndraws_tail = 20, return_k = TRUE)
ps_nochains <- pareto_smooth(tau_nochains, ndraws_tail = 20, return_k = TRUE)

expect_equal(as.numeric(ps_chains$x), as.numeric(ps_nochains$x))

Expand All @@ -159,22 +157,22 @@ test_that("pareto_khat functions work with rvars with and without chains", {
expect_equal(pareto_diags(tau_rvar_chains, ndraws_tail = 20),
pareto_diags(tau_rvar_nochains, ndraws_tail = 20))

ps_chains <- pareto_smooth(tau_chains, ndraws_tail = 20)
ps_rvar_chains <- pareto_smooth(tau_rvar_chains, ndraws_tail = 20)
ps_chains <- pareto_smooth(tau_chains, ndraws_tail = 20, return_k = TRUE)
ps_rvar_chains <- pareto_smooth(tau_rvar_chains, ndraws_tail = 20, return_k = TRUE)

ps_nochains <- pareto_smooth(tau_nochains, ndraws_tail = 20)
ps_rvar_nochains <- pareto_smooth(tau_rvar_nochains, ndraws_tail = 20)
ps_nochains <- pareto_smooth(tau_nochains, ndraws_tail = 20, return_k = TRUE)
ps_rvar_nochains <- pareto_smooth(tau_rvar_nochains, ndraws_tail = 20, return_k = TRUE)

expect_equal(ps_rvar_chains$x, rvar(ps_chains$x, with_chains = TRUE))

expect_equal(ps_rvar_nochains$x, rvar(ps_nochains$x))


ps_chains <- pareto_smooth(tau_chains, ndraws_tail = 20, extra_diags = TRUE)
ps_rvar_chains <- pareto_smooth(tau_rvar_chains, ndraws_tail = 20, extra_diags = TRUE)
ps_chains <- pareto_smooth(tau_chains, ndraws_tail = 20, extra_diags = TRUE, return_k = TRUE)
ps_rvar_chains <- pareto_smooth(tau_rvar_chains, ndraws_tail = 20, extra_diags = TRUE, return_k = TRUE)

ps_nochains <- pareto_smooth(tau_nochains, ndraws_tail = 20, extra_diags = TRUE)
ps_rvar_nochains <- pareto_smooth(tau_rvar_nochains, ndraws_tail = 20, extra_diags = TRUE)
ps_nochains <- pareto_smooth(tau_nochains, ndraws_tail = 20, extra_diags = TRUE, return_k = TRUE)
ps_rvar_nochains <- pareto_smooth(tau_rvar_nochains, ndraws_tail = 20, extra_diags = TRUE, return_k = TRUE)

expect_equal(ps_rvar_chains$x, rvar(ps_chains$x, with_chains = TRUE))

Expand All @@ -185,7 +183,7 @@ test_that("pareto_khat functions work with rvars with and without chains", {
test_that("pareto_smooth returns x with smoothed tail", {
tau <- extract_variable_matrix(example_draws(), "tau")

tau_smoothed <- pareto_smooth(tau, ndraws_tail = 10, tail = "right")$x
tau_smoothed <- pareto_smooth(tau, ndraws_tail = 10, tail = "right", return_k = TRUE)$x

expect_equal(sort(tau)[1:390], sort(tau_smoothed)[1:390])

Expand All @@ -197,7 +195,7 @@ test_that("pareto_smooth works for log_weights", {
w <- c(1:25, 1e3, 1e3, 1e3)
lw <- log(w)

ps <- pareto_smooth(lw, are_log_weights = TRUE, verbose = FALSE, ndraws_tail = 10)
ps <- pareto_smooth(lw, are_log_weights = TRUE, verbose = FALSE, ndraws_tail = 10, return_k = TRUE)

# only right tail is smoothed
expect_equal(ps$x[1:15], lw[1:15])
Expand Down

0 comments on commit 4f4613c

Please sign in to comment.