Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add nested R-hat convergence diagnostic #303

Merged
merged 13 commits into from
Oct 29, 2023
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,8 @@ S3method(rhat,default)
S3method(rhat,rvar)
S3method(rhat_basic,default)
S3method(rhat_basic,rvar)
S3method(rhat_nested,default)
S3method(rhat_nested,rvar)
S3method(sd,default)
S3method(sd,rvar)
S3method(split_chains,draws)
Expand Down Expand Up @@ -466,6 +468,7 @@ export(reserved_variables)
export(rfun)
export(rhat)
export(rhat_basic)
export(rhat_nested)
export(rstar)
export(rvar)
export(rvar_all)
Expand Down
1 change: 1 addition & 0 deletions R/convergence.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#' | [mcse_sd()] | Monte Carlo standard error for standard deviations |
#' | [rhat_basic()] | Basic version of Rhat |
#' | [rhat()] | Improved, rank-based version of Rhat |
#' | [rhat_nested()] | Rhat for use with many short chains |
#' | [rstar()] | R* diagnostic |
#'
#' @return
Expand Down
104 changes: 104 additions & 0 deletions R/nested_rhat.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
#' Nested Rhat convergence diagnostic
#'
#' Compute the Nested Rhat convergence diagnostic for a single variable
#' proposed in Margossian et al. (2023).
#'
#' @family diagnostics
#' @template args-conv
#' @param superchain_ids (numeric) Vector of length nchains specifying
#' which superchain each chain belongs to
#' @template args-methods-dots
#' @template return-conv
#' @template ref-margossian-nestedrhat-2023
#'
#' @examples
#' mu <- extract_variable_matrix(example_draws(), "mu")
#' rhat_nested(mu, superchain_ids = c(1, 1, 2, 2))
#'
#' d <- as_draws_rvars(example_draws("multi_normal"))
#' rhat_nested(d$Sigma, superchain_ids = c(1, 1, 2, 2))
#'
#' @export
rhat_nested <- function(x, ...) UseMethod("rhat_nested")

#' @rdname rhat_nested
#' @export
rhat_nested.default <- function(x, superchain_ids, ...) {
.rhat_nested(x, superchain_ids = superchain_ids)
}

#' @rdname rhat_nested
#' @export
rhat_nested.rvar <- function(x, superchain_ids, ...) {
summarise_rvar_by_element_with_chains(
x, rhat_nested, superchain_ids = superchain_ids, ...
)
}

.rhat_nested <- function(x, superchain_ids, ...) {
if (should_return_NA(x)) {
return(NA_real_)
}

x <- as.matrix(x)
niterations <- NROW(x)
nchains <- NCOL(x)

# check that all chains are assigned a superchain
if (length(superchain_ids) != nchains) {
warning_no_call("Length of superchain_ids not equal to number of chains, ",
"returning NA.")
return(NA_real_)
}

# check that superchains are equal length
superchain_id_table <- table(superchain_ids)
nchains_per_superchain <- max(superchain_id_table)

if (nchains_per_superchain != min(superchain_id_table)) {
warning_no_call("Number of chains per superchain is not the same for ",
"each superchain, returning NA.")
return(NA_real_)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically we could define nRhat to work on superchains with different sizes, but we didn't do this in the paper. I don't see a strong motivation for addressing this, but we can think about it in the future.

}

superchains <- unique(superchain_ids)

# mean and variance of chains calculated as in rhat
chain_mean <- matrixStats::colMeans2(x)
chain_var <- matrixStats::colVars(x, center = chain_mean)

# mean of superchains calculated by only including specified chains
# (equation 15 in Margossian et al. 2023)
superchain_mean <- sapply(
superchains, function(k) mean(x[, which(superchain_ids == k)])
)

# between-chain variance estimate (B_k in equation 18 in Margossian et al. 2023)
if (nchains_per_superchain == 1) {
var_between_chain <- 0
} else {
var_between_chain <- sapply(
superchains,
function(k) var(chain_mean[which(superchain_ids == k)])
)
}

# within-chain variance estimate (W_k in equation 18 in Margossian et al. 2023)
if (niterations == 1) {
var_within_chain <- 0
} else {
var_within_chain <- sapply(
superchains,
function(k) mean(chain_var[which(superchain_ids == k)])
)
}

# between-superchain variance (nB in equation 17 in Margossian et al. 2023)
var_between_superchain <- var(superchain_mean)

# within-superchain variance (nW in equation 18 in Margossian et al. 2023)
var_within_superchain <- mean(var_within_chain + var_between_chain)

# nested Rhat (nRhat in equation 19 in Margossian et al. 2023)
sqrt(1 + var_between_superchain / var_within_superchain)
}
5 changes: 5 additions & 0 deletions man-roxygen/ref-margossian-nestedrhat-2023.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#' @references
#' Charles C. Margossian, Matthew D. Hoffman, Pavel Sountsov, Lionel
#' Riou-Durand, Aki Vehtari and Andrew Gelman (2023). Nested R-hat:
#' Assessing the convergence of Markov chain Monte Carlo when running
#' many short chains. arxiv:arXiv:2110.13017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since you're listing equation numbers and the paper is still under review, make sure to list which version of the preprint you're referencing (the latest version is v4)

1 change: 1 addition & 0 deletions man/diagnostics.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/ess_basic.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/ess_bulk.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/ess_quantile.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/ess_sd.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/ess_tail.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mcse_mean.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mcse_quantile.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mcse_sd.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/rhat.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/rhat_basic.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

75 changes: 75 additions & 0 deletions man/rhat_nested.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/rstar.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

31 changes: 31 additions & 0 deletions tests/testthat/test-rhat_nested.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
test_that("rhat_nested returns reasonable values", {
tau <- extract_variable_matrix(example_draws(), "tau")

nested_rhat <- rhat_nested(tau, superchain_ids = c(1, 2, 1, 2))
expect_true(nested_rhat > 1 & nested_rhat < 1.05)

nested_rhat <- rhat_nested(tau, superchain_ids = c(1, 2, 1, 2))
expect_true(nested_rhat > 1 & nested_rhat < 1.05)
})


test_that("rhat_nested handles special cases correctly", {
set.seed(1234)
x <- c(rnorm(10), NA)
expect_true(is.na(rhat_nested(x, superchain_ids = c(1))))

x <- c(rnorm(10), Inf)
expect_true(is.na(rhat_nested(x, superchain_ids = c(1, 2, 1, 2))))

tau <- extract_variable_matrix(example_draws(), "tau")
expect_warning(
rhat_nested(tau, superchain_ids = c(1, 1, 1, 3)),
"Number of chains per superchain is not the same for each superchain, returning NA."
)

tau <- extract_variable_matrix(example_draws(), "tau")
expect_warning(
rhat_nested(tau, superchain_ids = c(1, 2)),
"Length of superchain_ids not equal to number of chains, returning NA."
)
})
Loading