-
-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
Add nested R-hat convergence diagnostic
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
#' Nested Rhat convergence diagnostic | ||
#' | ||
#' Compute the nested Rhat convergence diagnostic for a single | ||
#' variable as proposed in Margossian et al. (2023). | ||
#' | ||
#' @family diagnostics | ||
#' @template args-conv | ||
#' @param superchain_ids (numeric) Vector of length nchains specifying | ||
#' which superchain each chain belongs to. There should be equal | ||
#' numbers of chains in each superchain. All chains within the same | ||
#' superchain are assumed to have been initialized at the same | ||
#' point. | ||
#' @template args-methods-dots | ||
#' | ||
#' @details Nested Rhat is a convergence diagnostic useful when | ||
#' running many short chains. It is calculated on superchains, which | ||
#' are groups of chains that have been initialized at the same | ||
#' point. | ||
#' | ||
#' Note that there is a slight difference in the calculation of Rhat | ||
#' and nested Rhat, as nested Rhat is lower bounded by 1. This means | ||
#' that nested Rhat with one chain per superchain will not be | ||
#' exactly equal to basic Rhat (see Footnote 1 in Margossian et | ||
#' al. (2023)). | ||
#' | ||
#' @template return-conv | ||
#' @template ref-margossian-nestedrhat-2023 | ||
#' | ||
#' @examples | ||
#' mu <- extract_variable_matrix(example_draws(), "mu") | ||
#' rhat_nested(mu, superchain_ids = c(1, 1, 2, 2)) | ||
#' | ||
#' d <- as_draws_rvars(example_draws("multi_normal")) | ||
#' rhat_nested(d$Sigma, superchain_ids = c(1, 1, 2, 2)) | ||
#' | ||
#' @export | ||
rhat_nested <- function(x, ...) UseMethod("rhat_nested") | ||
|
||
#' @rdname rhat_nested | ||
#' @export | ||
rhat_nested.default <- function(x, superchain_ids, ...) { | ||
.rhat_nested(x, superchain_ids = superchain_ids) | ||
} | ||
|
||
#' @rdname rhat_nested | ||
#' @export | ||
rhat_nested.rvar <- function(x, superchain_ids, ...) { | ||
summarise_rvar_by_element_with_chains( | ||
x, rhat_nested, superchain_ids = superchain_ids, ... | ||
) | ||
} | ||
|
||
.rhat_nested <- function(x, superchain_ids, ...) { | ||
if (should_return_NA(x)) { | ||
return(NA_real_) | ||
} | ||
|
||
x <- as.matrix(x) | ||
niterations <- NROW(x) | ||
nchains <- NCOL(x) | ||
|
||
# check that all chains are assigned a superchain | ||
if (length(superchain_ids) != nchains) { | ||
warning_no_call("Length of superchain_ids not equal to number of chains, ", | ||
"returning NA.") | ||
return(NA_real_) | ||
} | ||
|
||
# check that superchains are equal length | ||
superchain_id_table <- table(superchain_ids) | ||
nchains_per_superchain <- max(superchain_id_table) | ||
|
||
if (nchains_per_superchain != min(superchain_id_table)) { | ||
warning_no_call("Number of chains per superchain is not the same for ", | ||
"each superchain, returning NA.") | ||
return(NA_real_) | ||
} | ||
|
||
superchains <- unique(superchain_ids) | ||
|
||
# mean and variance of chains calculated as in rhat | ||
chain_mean <- matrixStats::colMeans2(x) | ||
chain_var <- matrixStats::colVars(x, center = chain_mean) | ||
|
||
# mean of superchains calculated by only including specified chains | ||
# (equation 15 in Margossian et al. 2023) | ||
superchain_mean <- sapply( | ||
superchains, function(k) mean(x[, which(superchain_ids == k)]) | ||
) | ||
|
||
# between-chain variance estimate (B_k in equation 18 in Margossian et al. 2023) | ||
if (nchains_per_superchain == 1) { | ||
var_between_chain <- 0 | ||
} else { | ||
var_between_chain <- sapply( | ||
superchains, | ||
function(k) var(chain_mean[which(superchain_ids == k)]) | ||
) | ||
} | ||
|
||
# within-chain variance estimate (W_k in equation 18 in Margossian et al. 2023) | ||
if (niterations == 1) { | ||
var_within_chain <- 0 | ||
} else { | ||
var_within_chain <- sapply( | ||
superchains, | ||
function(k) mean(chain_var[which(superchain_ids == k)]) | ||
) | ||
} | ||
|
||
# between-superchain variance (nB in equation 17 in Margossian et al. 2023) | ||
var_between_superchain <- var(superchain_mean) | ||
|
||
# within-superchain variance (nW in equation 18 in Margossian et al. 2023) | ||
var_within_superchain <- mean(var_within_chain + var_between_chain) | ||
|
||
# nested Rhat (nRhat in equation 19 in Margossian et al. 2023) | ||
sqrt(1 + var_between_superchain / var_within_superchain) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#' @references | ||
#' Charles C. Margossian, Matthew D. Hoffman, Pavel Sountsov, Lionel | ||
#' Riou-Durand, Aki Vehtari and Andrew Gelman (2023). Nested R-hat: | ||
#' Assessing the convergence of Markov chain Monte Carlo when running | ||
#' many short chains. arxiv:arXiv:2110.13017 (version 4) |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
test_that("rhat_nested returns reasonable values", { | ||
tau <- extract_variable_matrix(example_draws(), "tau") | ||
|
||
nested_rhat <- rhat_nested(tau, superchain_ids = c(1, 2, 1, 2)) | ||
expect_true(nested_rhat > 1 & nested_rhat < 1.05) | ||
|
||
nested_rhat <- rhat_nested(tau, superchain_ids = c(1, 2, 1, 2)) | ||
expect_true(nested_rhat > 1 & nested_rhat < 1.05) | ||
}) | ||
|
||
|
||
test_that("rhat_nested handles special cases correctly", { | ||
set.seed(1234) | ||
x <- c(rnorm(10), NA) | ||
expect_true(is.na(rhat_nested(x, superchain_ids = c(1)))) | ||
|
||
x <- c(rnorm(10), Inf) | ||
expect_true(is.na(rhat_nested(x, superchain_ids = c(1, 2, 1, 2)))) | ||
|
||
tau <- extract_variable_matrix(example_draws(), "tau") | ||
expect_warning( | ||
rhat_nested(tau, superchain_ids = c(1, 1, 1, 3)), | ||
"Number of chains per superchain is not the same for each superchain, returning NA." | ||
) | ||
|
||
tau <- extract_variable_matrix(example_draws(), "tau") | ||
expect_warning( | ||
rhat_nested(tau, superchain_ids = c(1, 2)), | ||
"Length of superchain_ids not equal to number of chains, returning NA." | ||
) | ||
}) |