Skip to content

Commit

Permalink
cleanup nested rhat and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
n-kall committed Oct 13, 2023
1 parent 3632a2f commit 9e5a44f
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 16 deletions.
48 changes: 33 additions & 15 deletions R/nested_rhat.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#' rhat_nested(mu, superchain_ids = c(1,1,2,2))
#'
#' d <- as_draws_rvars(example_draws("multi_normal"))
#' rhat(d$Sigma, superchain_ids = c(1,1,2,2))
#' rhat_nested(d$Sigma, superchain_ids = c(1,1,2,2))
#'
#' @export
rhat_nested <- function(x, superchain_ids, ...) UseMethod("rhat_nested")
Expand All @@ -33,38 +33,56 @@ rhat_nested.rvar <- function(x, superchain_ids, ...) {
summarise_rvar_by_element_with_chains(x, rhat_nested, superchain_ids = superchain_ids, ...)
}


.rhat_nested <- function(x, superchain_ids, ...) {

array_dims <- dim(x)
ndraws <- array_dims[1]
x <- as.matrix(x)
niterations <- NROW(x)
nchains_per_superchain <- max(table(superchain_ids))
nsuperchains <- length(unique(superchain_ids))
superchains <- unique(superchain_ids)

# mean and variance of chains calculated as in rhat
chain_mean <- matrixStats::colMeans2(x)
chain_var <- matrixStats::colVars(x, center = chain_mean)

superchain_mean <- sapply(unique(superchain_ids), function(k) mean(x[, which(superchain_ids == k)]))
# mean of superchains calculated by only including specified chains
# (equation 15 in Margossian et al. 2023)
superchain_mean <- sapply(
superchains,
function(k) mean(x[, which(superchain_ids == k)])
)

chain_mean <- matrix(matrixStats::colMeans2(x), nrow = 1)
chain_var <- matrixStats::colVars(x, center=chain_mean)

# overall mean (as defined in equation 16 in Margossian et al. 2023)
overall_mean <- mean(superchain_mean)

# between-chain variance estimate (B_k in equation 18 in Margossian et al. 2023)
if (nchains_per_superchain == 1) {
var_between_chain <- 0
} else {
var_between_chain <- sapply(unique(superchain_ids), function(k) var(chain_mean[, which(superchain_ids == k)]))
var_between_chain <- sapply(
superchains,
function(k) var(chain_mean[which(superchain_ids == k)])
)
}
if (ndraws == 1) {

# within-chain variance estimate (W_k in equation 18 in Margossian et al. 2023)
if (niterations == 1) {
var_within_chain <- 0
} else {
var_within_chain <- sapply(unique(superchain_ids), function(k) mean(chain_var[which(superchain_ids == k)]))
var_within_chain <- sapply(
superchains,
function(k) mean(chain_var[which(superchain_ids == k)])
)
}


# between-superchain variance (nB in equation 17 in Margossian et al. 2023)
var_between_superchain <- matrixStats::colVars(
as.matrix(superchain_mean),
center = overall_mean
)


# within-superchain variance (nW in equation 18 in Margossian et al. 2023)
var_within_superchain <- mean(var_within_chain + var_between_chain)

sqrt(1 + var_between_superchain / var_within_superchain)
# nested Rhat (nRhat in equation 19 in Margossian et al. 2023)
sqrt(1 + var_between_superchain / var_within_superchain)
}
2 changes: 1 addition & 1 deletion man/rhat_nested.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions tests/testthat/test-rhat_nested.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
test_that("rhat_nested returns reasonable values", {
tau <- extract_variable_matrix(example_draws(), "tau")

rhat <- rhat_nested(tau, superchain_ids = c(1, 2, 1, 2))
expect_true(rhat > 0.99 & rhat < 1.05)

rhat <- rhat_nested(tau, superchain_ids = c(1, 2, 1, 2))
expect_true(rhat > 0.99 & rhat < 1.05)
})


0 comments on commit 9e5a44f

Please sign in to comment.