diff --git a/R/nested_rhat.R b/R/nested_rhat.R index c1470e59..f2eebff5 100644 --- a/R/nested_rhat.R +++ b/R/nested_rhat.R @@ -1,36 +1,3 @@ -.add_superchain_ids <- function(draws, superchain_ids) { - - # determine size of dims - chains_per_superchain <- table(superchain_ids) - num_chains_per_superchain <- max(chains_per_superchain) - num_iterations <- dim(draws)[1] - num_superchains <- max(superchain_ids) - - # create new empty array with correct dims - new_draws <- array( - NA, - dim = c( - num_iterations, - num_chains_per_superchain, - num_superchains) - ) - - # add dim names - dimnames(new_draws) <- list( - iteration = 1:num_iterations, - chain = 1:num_chains_per_superchain, - superchain = 1:num_superchains - ) - - # assign chains to superchains - for (k in 1:num_superchains) { - chains_in_superchain <- which(superchain_ids == k) - new_draws[, , k] <- draws[, chains_in_superchain] - } - - return(new_draws) -} - #' Nested Rhat convergence diagnostic #' #' Compute the Nested Rhat convergence diagnostic for a single variable @@ -57,37 +24,39 @@ rhat_nested <- function(x, superchain_ids, ...) UseMethod("rhat_nested") #' @rdname rhat_nested #' @export rhat_nested.default <- function(x, superchain_ids, ...) { + .rhat_nested(x, superchain_ids = superchain_ids) +} - x <- .add_superchain_ids(x, superchain_ids) - .rhat_nested(x) +#' @rdname rhat_nested +#' @export +rhat_nested.rvar <- function(x, superchain_ids, ...) { + summarise_rvar_by_element_with_chains(x, rhat_nested, superchain_ids = superchain_ids, ...) } -.rhat_nested <- function(x, ...) { + +.rhat_nested <- function(x, superchain_ids, ...) { array_dims <- dim(x) ndraws <- array_dims[1] - nchains <- array_dims[2] - nsuperchains <- array_dims[3] - - superchain_mean <- apply(x, 3, mean) - chain_mean <- apply(x, c(2, 3), mean) - chain_var <- apply(x, c(2, 3), var) + nchains_per_superchain <- max(table(superchain_ids)) + nsuperchains <- length(unique(superchain_ids)) + + superchain_mean <- sapply(unique(superchain_ids), function(k) mean(x[, which(superchain_ids == k)])) + + chain_mean <- matrix(matrixStats::colMeans2(x), nrow = 1) + chain_var <- matrixStats::colVars(x, center=chain_mean) overall_mean <- mean(superchain_mean) - if (nchains == 1) { + if (nchains_per_superchain == 1) { var_between_chain <- 0 } else { - var_between_chain <- matrixStats::colVars( - chain_mean, - center = superchain_mean - ) + var_between_chain <- sapply(unique(superchain_ids), function(k) var(chain_mean[, which(superchain_ids == k)])) } - if (ndraws == 1) { var_within_chain <- 0 } else { - var_within_chain <- colMeans(chain_var) + var_within_chain <- sapply(unique(superchain_ids), function(k) mean(chain_var[which(superchain_ids == k)])) } var_between_superchain <- matrixStats::colVars( @@ -97,5 +66,5 @@ rhat_nested.default <- function(x, superchain_ids, ...) { var_within_superchain <- mean(var_within_chain + var_between_chain) - sqrt(1 + var_between_superchain / var_within_superchain) + sqrt(1 + var_between_superchain / var_within_superchain) }