Skip to content

Commit

Permalink
Merge pull request #303 from n-kall/nested_rhat
Browse files Browse the repository at this point in the history
Add nested R-hat convergence diagnostic
  • Loading branch information
paul-buerkner authored Oct 29, 2023
2 parents a362feb + 32f97c4 commit 6add3e6
Show file tree
Hide file tree
Showing 18 changed files with 261 additions and 0 deletions.
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,8 @@ S3method(rhat,default)
S3method(rhat,rvar)
S3method(rhat_basic,default)
S3method(rhat_basic,rvar)
S3method(rhat_nested,default)
S3method(rhat_nested,rvar)
S3method(sd,default)
S3method(sd,rvar)
S3method(split_chains,draws)
Expand Down Expand Up @@ -466,6 +468,7 @@ export(reserved_variables)
export(rfun)
export(rhat)
export(rhat_basic)
export(rhat_nested)
export(rstar)
export(rvar)
export(rvar_all)
Expand Down
1 change: 1 addition & 0 deletions R/convergence.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#' | [mcse_sd()] | Monte Carlo standard error for standard deviations |
#' | [rhat_basic()] | Basic version of Rhat |
#' | [rhat()] | Improved, rank-based version of Rhat |
#' | [rhat_nested()] | Rhat for use with many short chains |
#' | [rstar()] | R* diagnostic |
#'
#' @return
Expand Down
119 changes: 119 additions & 0 deletions R/nested_rhat.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#' Nested Rhat convergence diagnostic
#'
#' Compute the nested Rhat convergence diagnostic for a single
#' variable as proposed in Margossian et al. (2023).
#'
#' @family diagnostics
#' @template args-conv
#' @param superchain_ids (numeric) Vector of length nchains specifying
#' which superchain each chain belongs to. There should be equal
#' numbers of chains in each superchain. All chains within the same
#' superchain are assumed to have been initialized at the same
#' point.
#' @template args-methods-dots
#'
#' @details Nested Rhat is a convergence diagnostic useful when
#' running many short chains. It is calculated on superchains, which
#' are groups of chains that have been initialized at the same
#' point.
#'
#' Note that there is a slight difference in the calculation of Rhat
#' and nested Rhat, as nested Rhat is lower bounded by 1. This means
#' that nested Rhat with one chain per superchain will not be
#' exactly equal to basic Rhat (see Footnote 1 in Margossian et
#' al. (2023)).
#'
#' @template return-conv
#' @template ref-margossian-nestedrhat-2023
#'
#' @examples
#' mu <- extract_variable_matrix(example_draws(), "mu")
#' rhat_nested(mu, superchain_ids = c(1, 1, 2, 2))
#'
#' d <- as_draws_rvars(example_draws("multi_normal"))
#' rhat_nested(d$Sigma, superchain_ids = c(1, 1, 2, 2))
#'
#' @export
rhat_nested <- function(x, ...) UseMethod("rhat_nested")

#' @rdname rhat_nested
#' @export
rhat_nested.default <- function(x, superchain_ids, ...) {
.rhat_nested(x, superchain_ids = superchain_ids)
}

#' @rdname rhat_nested
#' @export
rhat_nested.rvar <- function(x, superchain_ids, ...) {
summarise_rvar_by_element_with_chains(
x, rhat_nested, superchain_ids = superchain_ids, ...
)
}

.rhat_nested <- function(x, superchain_ids, ...) {
if (should_return_NA(x)) {
return(NA_real_)
}

x <- as.matrix(x)
niterations <- NROW(x)
nchains <- NCOL(x)

# check that all chains are assigned a superchain
if (length(superchain_ids) != nchains) {
warning_no_call("Length of superchain_ids not equal to number of chains, ",
"returning NA.")
return(NA_real_)
}

# check that superchains are equal length
superchain_id_table <- table(superchain_ids)
nchains_per_superchain <- max(superchain_id_table)

if (nchains_per_superchain != min(superchain_id_table)) {
warning_no_call("Number of chains per superchain is not the same for ",
"each superchain, returning NA.")
return(NA_real_)
}

superchains <- unique(superchain_ids)

# mean and variance of chains calculated as in rhat
chain_mean <- matrixStats::colMeans2(x)
chain_var <- matrixStats::colVars(x, center = chain_mean)

# mean of superchains calculated by only including specified chains
# (equation 15 in Margossian et al. 2023)
superchain_mean <- sapply(
superchains, function(k) mean(x[, which(superchain_ids == k)])
)

# between-chain variance estimate (B_k in equation 18 in Margossian et al. 2023)
if (nchains_per_superchain == 1) {
var_between_chain <- 0
} else {
var_between_chain <- sapply(
superchains,
function(k) var(chain_mean[which(superchain_ids == k)])
)
}

# within-chain variance estimate (W_k in equation 18 in Margossian et al. 2023)
if (niterations == 1) {
var_within_chain <- 0
} else {
var_within_chain <- sapply(
superchains,
function(k) mean(chain_var[which(superchain_ids == k)])
)
}

# between-superchain variance (nB in equation 17 in Margossian et al. 2023)
var_between_superchain <- var(superchain_mean)

# within-superchain variance (nW in equation 18 in Margossian et al. 2023)
var_within_superchain <- mean(var_within_chain + var_between_chain)

# nested Rhat (nRhat in equation 19 in Margossian et al. 2023)
sqrt(1 + var_between_superchain / var_within_superchain)
}
5 changes: 5 additions & 0 deletions man-roxygen/ref-margossian-nestedrhat-2023.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#' @references
#' Charles C. Margossian, Matthew D. Hoffman, Pavel Sountsov, Lionel
#' Riou-Durand, Aki Vehtari and Andrew Gelman (2023). Nested R-hat:
#' Assessing the convergence of Markov chain Monte Carlo when running
#' many short chains. arxiv:arXiv:2110.13017 (version 4)
1 change: 1 addition & 0 deletions man/diagnostics.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/ess_basic.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/ess_bulk.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/ess_quantile.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/ess_sd.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/ess_tail.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mcse_mean.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mcse_quantile.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mcse_sd.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/rhat.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/rhat_basic.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

90 changes: 90 additions & 0 deletions man/rhat_nested.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/rstar.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

31 changes: 31 additions & 0 deletions tests/testthat/test-rhat_nested.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
test_that("rhat_nested returns reasonable values", {
tau <- extract_variable_matrix(example_draws(), "tau")

nested_rhat <- rhat_nested(tau, superchain_ids = c(1, 2, 1, 2))
expect_true(nested_rhat > 1 & nested_rhat < 1.05)

nested_rhat <- rhat_nested(tau, superchain_ids = c(1, 2, 1, 2))
expect_true(nested_rhat > 1 & nested_rhat < 1.05)
})


test_that("rhat_nested handles special cases correctly", {
set.seed(1234)
x <- c(rnorm(10), NA)
expect_true(is.na(rhat_nested(x, superchain_ids = c(1))))

x <- c(rnorm(10), Inf)
expect_true(is.na(rhat_nested(x, superchain_ids = c(1, 2, 1, 2))))

tau <- extract_variable_matrix(example_draws(), "tau")
expect_warning(
rhat_nested(tau, superchain_ids = c(1, 1, 1, 3)),
"Number of chains per superchain is not the same for each superchain, returning NA."
)

tau <- extract_variable_matrix(example_draws(), "tau")
expect_warning(
rhat_nested(tau, superchain_ids = c(1, 2)),
"Length of superchain_ids not equal to number of chains, returning NA."
)
})

0 comments on commit 6add3e6

Please sign in to comment.