diff --git a/R/nested_rhat.R b/R/nested_rhat.R index f2eebff5..bfca55c6 100644 --- a/R/nested_rhat.R +++ b/R/nested_rhat.R @@ -16,7 +16,7 @@ #' rhat_nested(mu, superchain_ids = c(1,1,2,2)) #' #' d <- as_draws_rvars(example_draws("multi_normal")) -#' rhat(d$Sigma, superchain_ids = c(1,1,2,2)) +#' rhat_nested(d$Sigma, superchain_ids = c(1,1,2,2)) #' #' @export rhat_nested <- function(x, superchain_ids, ...) UseMethod("rhat_nested") @@ -33,38 +33,56 @@ rhat_nested.rvar <- function(x, superchain_ids, ...) { summarise_rvar_by_element_with_chains(x, rhat_nested, superchain_ids = superchain_ids, ...) } - .rhat_nested <- function(x, superchain_ids, ...) { - array_dims <- dim(x) - ndraws <- array_dims[1] + x <- as.matrix(x) + niterations <- NROW(x) nchains_per_superchain <- max(table(superchain_ids)) - nsuperchains <- length(unique(superchain_ids)) + superchains <- unique(superchain_ids) + + # mean and variance of chains calculated as in rhat + chain_mean <- matrixStats::colMeans2(x) + chain_var <- matrixStats::colVars(x, center = chain_mean) - superchain_mean <- sapply(unique(superchain_ids), function(k) mean(x[, which(superchain_ids == k)])) + # mean of superchains calculated by only including specified chains + # (equation 15 in Margossian et al. 2023) + superchain_mean <- sapply( + superchains, + function(k) mean(x[, which(superchain_ids == k)]) + ) - chain_mean <- matrix(matrixStats::colMeans2(x), nrow = 1) - chain_var <- matrixStats::colVars(x, center=chain_mean) - + # overall mean (as defined in equation 16 in Margossian et al. 2023) overall_mean <- mean(superchain_mean) + # between-chain variance estimate (B_k in equation 18 in Margossian et al. 2023) if (nchains_per_superchain == 1) { var_between_chain <- 0 } else { - var_between_chain <- sapply(unique(superchain_ids), function(k) var(chain_mean[, which(superchain_ids == k)])) + var_between_chain <- sapply( + superchains, + function(k) var(chain_mean[which(superchain_ids == k)]) + ) } - if (ndraws == 1) { + + # within-chain variance estimate (W_k in equation 18 in Margossian et al. 2023) + if (niterations == 1) { var_within_chain <- 0 } else { - var_within_chain <- sapply(unique(superchain_ids), function(k) mean(chain_var[which(superchain_ids == k)])) + var_within_chain <- sapply( + superchains, + function(k) mean(chain_var[which(superchain_ids == k)]) + ) } - + + # between-superchain variance (nB in equation 17 in Margossian et al. 2023) var_between_superchain <- matrixStats::colVars( as.matrix(superchain_mean), center = overall_mean ) - + + # within-superchain variance (nW in equation 18 in Margossian et al. 2023) var_within_superchain <- mean(var_within_chain + var_between_chain) - sqrt(1 + var_between_superchain / var_within_superchain) + # nested Rhat (nRhat in equation 19 in Margossian et al. 2023) + sqrt(1 + var_between_superchain / var_within_superchain) } diff --git a/man/rhat_nested.Rd b/man/rhat_nested.Rd index d741cd7c..5988b91a 100644 --- a/man/rhat_nested.Rd +++ b/man/rhat_nested.Rd @@ -49,7 +49,7 @@ mu <- extract_variable_matrix(example_draws(), "mu") rhat_nested(mu, superchain_ids = c(1,1,2,2)) d <- as_draws_rvars(example_draws("multi_normal")) -rhat(d$Sigma, superchain_ids = c(1,1,2,2)) +rhat_nested(d$Sigma, superchain_ids = c(1,1,2,2)) } \references{ diff --git a/tests/testthat/test-rhat_nested.R b/tests/testthat/test-rhat_nested.R new file mode 100644 index 00000000..1d71ea9e --- /dev/null +++ b/tests/testthat/test-rhat_nested.R @@ -0,0 +1,11 @@ +test_that("rhat_nested returns reasonable values", { + tau <- extract_variable_matrix(example_draws(), "tau") + + rhat <- rhat_nested(tau, superchain_ids = c(1, 2, 1, 2)) + expect_true(rhat > 0.99 & rhat < 1.05) + + rhat <- rhat_nested(tau, superchain_ids = c(1, 2, 1, 2)) + expect_true(rhat > 0.99 & rhat < 1.05) +}) + +