-
-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add nested R-hat convergence diagnostic #303
Changes from all commits
9d375b8
df28355
e02321d
3632a2f
9e5a44f
32cd38b
04f30ab
67b9e0b
86b64c8
17bf757
d521dae
2fb6270
32f97c4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
#' Nested Rhat convergence diagnostic | ||
#' | ||
#' Compute the nested Rhat convergence diagnostic for a single | ||
#' variable as proposed in Margossian et al. (2023). | ||
#' | ||
#' @family diagnostics | ||
#' @template args-conv | ||
#' @param superchain_ids (numeric) Vector of length nchains specifying | ||
#' which superchain each chain belongs to. There should be equal | ||
#' numbers of chains in each superchain. All chains within the same | ||
#' superchain are assumed to have been initialized at the same | ||
#' point. | ||
#' @template args-methods-dots | ||
#' | ||
#' @details Nested Rhat is a convergence diagnostic useful when | ||
#' running many short chains. It is calculated on superchains, which | ||
#' are groups of chains that have been initialized at the same | ||
#' point. | ||
#' | ||
#' Note that there is a slight difference in the calculation of Rhat | ||
#' and nested Rhat, as nested Rhat is lower bounded by 1. This means | ||
#' that nested Rhat with one chain per superchain will not be | ||
#' exactly equal to basic Rhat (see Footnote 1 in Margossian et | ||
#' al. (2023)). | ||
#' | ||
#' @template return-conv | ||
#' @template ref-margossian-nestedrhat-2023 | ||
#' | ||
#' @examples | ||
#' mu <- extract_variable_matrix(example_draws(), "mu") | ||
#' rhat_nested(mu, superchain_ids = c(1, 1, 2, 2)) | ||
#' | ||
#' d <- as_draws_rvars(example_draws("multi_normal")) | ||
#' rhat_nested(d$Sigma, superchain_ids = c(1, 1, 2, 2)) | ||
#' | ||
#' @export | ||
rhat_nested <- function(x, ...) UseMethod("rhat_nested") | ||
|
||
#' @rdname rhat_nested | ||
#' @export | ||
rhat_nested.default <- function(x, superchain_ids, ...) { | ||
.rhat_nested(x, superchain_ids = superchain_ids) | ||
} | ||
|
||
#' @rdname rhat_nested | ||
#' @export | ||
rhat_nested.rvar <- function(x, superchain_ids, ...) { | ||
summarise_rvar_by_element_with_chains( | ||
x, rhat_nested, superchain_ids = superchain_ids, ... | ||
) | ||
} | ||
|
||
.rhat_nested <- function(x, superchain_ids, ...) { | ||
if (should_return_NA(x)) { | ||
return(NA_real_) | ||
} | ||
|
||
x <- as.matrix(x) | ||
niterations <- NROW(x) | ||
nchains <- NCOL(x) | ||
|
||
# check that all chains are assigned a superchain | ||
if (length(superchain_ids) != nchains) { | ||
warning_no_call("Length of superchain_ids not equal to number of chains, ", | ||
"returning NA.") | ||
return(NA_real_) | ||
} | ||
|
||
# check that superchains are equal length | ||
superchain_id_table <- table(superchain_ids) | ||
nchains_per_superchain <- max(superchain_id_table) | ||
|
||
if (nchains_per_superchain != min(superchain_id_table)) { | ||
warning_no_call("Number of chains per superchain is not the same for ", | ||
"each superchain, returning NA.") | ||
return(NA_real_) | ||
} | ||
|
||
superchains <- unique(superchain_ids) | ||
|
||
# mean and variance of chains calculated as in rhat | ||
chain_mean <- matrixStats::colMeans2(x) | ||
chain_var <- matrixStats::colVars(x, center = chain_mean) | ||
|
||
# mean of superchains calculated by only including specified chains | ||
# (equation 15 in Margossian et al. 2023) | ||
superchain_mean <- sapply( | ||
superchains, function(k) mean(x[, which(superchain_ids == k)]) | ||
) | ||
|
||
# between-chain variance estimate (B_k in equation 18 in Margossian et al. 2023) | ||
if (nchains_per_superchain == 1) { | ||
var_between_chain <- 0 | ||
} else { | ||
var_between_chain <- sapply( | ||
superchains, | ||
function(k) var(chain_mean[which(superchain_ids == k)]) | ||
) | ||
} | ||
|
||
# within-chain variance estimate (W_k in equation 18 in Margossian et al. 2023) | ||
if (niterations == 1) { | ||
var_within_chain <- 0 | ||
} else { | ||
var_within_chain <- sapply( | ||
superchains, | ||
function(k) mean(chain_var[which(superchain_ids == k)]) | ||
) | ||
} | ||
|
||
# between-superchain variance (nB in equation 17 in Margossian et al. 2023) | ||
var_between_superchain <- var(superchain_mean) | ||
|
||
# within-superchain variance (nW in equation 18 in Margossian et al. 2023) | ||
var_within_superchain <- mean(var_within_chain + var_between_chain) | ||
|
||
# nested Rhat (nRhat in equation 19 in Margossian et al. 2023) | ||
sqrt(1 + var_between_superchain / var_within_superchain) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#' @references | ||
#' Charles C. Margossian, Matthew D. Hoffman, Pavel Sountsov, Lionel | ||
#' Riou-Durand, Aki Vehtari and Andrew Gelman (2023). Nested R-hat: | ||
#' Assessing the convergence of Markov chain Monte Carlo when running | ||
#' many short chains. arxiv:arXiv:2110.13017 (version 4) |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
test_that("rhat_nested returns reasonable values", { | ||
tau <- extract_variable_matrix(example_draws(), "tau") | ||
|
||
nested_rhat <- rhat_nested(tau, superchain_ids = c(1, 2, 1, 2)) | ||
expect_true(nested_rhat > 1 & nested_rhat < 1.05) | ||
|
||
nested_rhat <- rhat_nested(tau, superchain_ids = c(1, 2, 1, 2)) | ||
expect_true(nested_rhat > 1 & nested_rhat < 1.05) | ||
}) | ||
|
||
|
||
test_that("rhat_nested handles special cases correctly", { | ||
set.seed(1234) | ||
x <- c(rnorm(10), NA) | ||
expect_true(is.na(rhat_nested(x, superchain_ids = c(1)))) | ||
|
||
x <- c(rnorm(10), Inf) | ||
expect_true(is.na(rhat_nested(x, superchain_ids = c(1, 2, 1, 2)))) | ||
|
||
tau <- extract_variable_matrix(example_draws(), "tau") | ||
expect_warning( | ||
rhat_nested(tau, superchain_ids = c(1, 1, 1, 3)), | ||
"Number of chains per superchain is not the same for each superchain, returning NA." | ||
) | ||
|
||
tau <- extract_variable_matrix(example_draws(), "tau") | ||
expect_warning( | ||
rhat_nested(tau, superchain_ids = c(1, 2)), | ||
"Length of superchain_ids not equal to number of chains, returning NA." | ||
) | ||
}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Technically we could define nRhat to work on superchains with different sizes, but we didn't do this in the paper. I don't see a strong motivation for addressing this, but we can think about it in the future.