-
-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add nested R-hat convergence diagnostic #303
Changes from 10 commits
9d375b8
df28355
e02321d
3632a2f
9e5a44f
32cd38b
04f30ab
67b9e0b
86b64c8
17bf757
d521dae
2fb6270
32f97c4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
#' Nested Rhat convergence diagnostic
#'
#' Compute the nested Rhat convergence diagnostic for a single variable,
#' as proposed in Margossian et al. (2023). Chains are grouped into
#' superchains and the variance between superchain means is compared with
#' the variance within superchains.
#'
#' @family diagnostics
#' @template args-conv
#' @param superchain_ids (numeric) Vector of length nchains specifying
#'   which superchain each chain belongs to. Every superchain must contain
#'   the same number of chains; otherwise `NA` is returned with a warning.
#' @template args-methods-dots
#' @template return-conv
#' @template ref-margossian-nestedrhat-2023
#'
#' @examples
#' mu <- extract_variable_matrix(example_draws(), "mu")
#' rhat_nested(mu, superchain_ids = c(1, 1, 2, 2))
#'
#' d <- as_draws_rvars(example_draws("multi_normal"))
#' rhat_nested(d$Sigma, superchain_ids = c(1, 1, 2, 2))
#'
#' @export
rhat_nested <- function(x, ...) {
  UseMethod("rhat_nested")
}
|
||
#' @rdname rhat_nested
#' @export
rhat_nested.default <- function(x, superchain_ids, ...) {
  # Arguments in `...` are accepted but not passed on to the worker.
  nested_rhat <- .rhat_nested(x, superchain_ids = superchain_ids)
  nested_rhat
}
|
||
#' @rdname rhat_nested
#' @export
rhat_nested.rvar <- function(x, superchain_ids, ...) {
  # Delegate to summarise_rvar_by_element_with_chains(), handing it the
  # rhat_nested generic together with superchain_ids and any further
  # arguments supplied through `...`.
  summarise_rvar_by_element_with_chains(
    x,
    rhat_nested,
    superchain_ids = superchain_ids,
    ...
  )
}
|
||
# Nested Rhat for a single variable.
#
# @param x draws of one variable; coerced with as.matrix() and then treated
#   as one row per iteration and one column per chain.
# @param superchain_ids vector holding one superchain label per chain.
# @param ... currently unused.
# @return A single numeric value: the nested Rhat, or NA_real_ when
#   should_return_NA(x) is TRUE or when the superchain assignment is invalid
#   (wrong length, missing labels, or unequal superchain sizes; a warning is
#   issued in each of these cases).
.rhat_nested <- function(x, superchain_ids, ...) {
  if (should_return_NA(x)) {
    return(NA_real_)
  }

  x <- as.matrix(x)
  niterations <- NROW(x)
  nchains <- NCOL(x)

  # check that all chains are assigned a superchain
  if (length(superchain_ids) != nchains) {
    warning_no_call("Length of superchain_ids not equal to number of chains, ",
                    "returning NA.")
    return(NA_real_)
  }

  # table() below drops missing labels, so without this check a missing id
  # would pass validation and then select no chains at all for its group
  if (anyNA(superchain_ids)) {
    warning_no_call("superchain_ids contains missing values, returning NA.")
    return(NA_real_)
  }

  # check that superchains are equal length
  superchain_id_table <- table(superchain_ids)
  nchains_per_superchain <- max(superchain_id_table)

  if (nchains_per_superchain != min(superchain_id_table)) {
    warning_no_call("Number of chains per superchain is not the same for ",
                    "each superchain, returning NA.")
    return(NA_real_)
  }

  superchains <- unique(superchain_ids)

  # mean and variance of chains calculated as in rhat
  chain_mean <- matrixStats::colMeans2(x)
  chain_var <- matrixStats::colVars(x, center = chain_mean)

  # mean of superchains calculated by only including specified chains
  # (equation 15 in Margossian et al. 2023)
  superchain_mean <- vapply(
    superchains,
    function(k) mean(x[, superchain_ids == k]),
    numeric(1)
  )

  # between-chain variance estimate (B_k in equation 18 in Margossian et al.
  # 2023); with one chain per superchain there is nothing to compare
  if (nchains_per_superchain == 1) {
    var_between_chain <- 0
  } else {
    var_between_chain <- vapply(
      superchains,
      function(k) var(chain_mean[superchain_ids == k]),
      numeric(1)
    )
  }

  # within-chain variance estimate (W_k in equation 18 in Margossian et al.
  # 2023)
  if (niterations == 1) {
    var_within_chain <- 0
  } else {
    var_within_chain <- vapply(
      superchains,
      function(k) mean(chain_var[superchain_ids == k]),
      numeric(1)
    )
  }

  # between-superchain variance (nB in equation 17 in Margossian et al. 2023)
  var_between_superchain <- var(superchain_mean)

  # within-superchain variance (nW in equation 18 in Margossian et al. 2023)
  var_within_superchain <- mean(var_within_chain + var_between_chain)

  # nested Rhat (nRhat in equation 19 in Margossian et al. 2023)
  sqrt(1 + var_between_superchain / var_within_superchain)
}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#' @references | ||
#' Charles C. Margossian, Matthew D. Hoffman, Pavel Sountsov, Lionel | ||
#' Riou-Durand, Aki Vehtari and Andrew Gelman (2023). Nested R-hat: | ||
#' Assessing the convergence of Markov chain Monte Carlo when running | ||
#' many short chains. arXiv:2110.13017
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since you're listing equation numbers and the paper is still under review, make sure to list which version of the preprint you're referencing (the latest version is v4) |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
test_that("rhat_nested returns reasonable values", {
  tau <- extract_variable_matrix(example_draws(), "tau")

  nested_rhat <- rhat_nested(tau, superchain_ids = c(1, 2, 1, 2))
  expect_gt(nested_rhat, 1)
  expect_lt(nested_rhat, 1.05)

  # Swapping the superchain labels keeps the same chains grouped together,
  # so the diagnostic must not change.
  relabelled_rhat <- rhat_nested(tau, superchain_ids = c(2, 1, 2, 1))
  expect_equal(relabelled_rhat, nested_rhat)
})
|
||
|
||
test_that("rhat_nested handles special cases correctly", {
  set.seed(1234)

  # draws containing NA or Inf are expected to yield NA
  draws_with_na <- c(rnorm(10), NA)
  na_result <- rhat_nested(draws_with_na, superchain_ids = c(1))
  expect_true(is.na(na_result))

  draws_with_inf <- c(rnorm(10), Inf)
  inf_result <- rhat_nested(draws_with_inf, superchain_ids = c(1, 2, 1, 2))
  expect_true(is.na(inf_result))

  tau <- extract_variable_matrix(example_draws(), "tau")

  # superchains of unequal size
  expect_warning(
    rhat_nested(tau, superchain_ids = c(1, 1, 1, 3)),
    "Number of chains per superchain is not the same for each superchain, returning NA."
  )

  # fewer superchain ids than chains
  expect_warning(
    rhat_nested(tau, superchain_ids = c(1, 2)),
    "Length of superchain_ids not equal to number of chains, returning NA."
  )
})
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Technically we could define nRhat to work on superchains with different sizes, but we didn't do this in the paper. I don't see a strong motivation for addressing this, but we can think about it in the future.