Skip to content

Commit

Permalink
improve memory efficiency of rhat_nested by not creating a new array
Browse files Browse the repository at this point in the history
  • Loading branch information
n-kall committed Oct 13, 2023
1 parent df28355 commit e02321d
Showing 1 changed file with 19 additions and 50 deletions.
69 changes: 19 additions & 50 deletions R/nested_rhat.R
Original file line number Diff line number Diff line change
@@ -1,36 +1,3 @@
# Rearrange a draws matrix into an iteration x chain x superchain array.
#
# @param draws A 2-d matrix of draws: one row per iteration, one column per
#   chain.
# @param superchain_ids A vector with one entry per column of `draws`, giving
#   the superchain each chain belongs to. Every superchain must contain the
#   same number of chains.
# @return A 3-d array with dimnames `iteration`, `chain` and `superchain`.
#   Slice `[, , k]` holds the chains of the k-th superchain (superchains in
#   ascending id order, chains in their original column order).
.add_superchain_ids <- function(draws, superchain_ids) {
  if (length(dim(draws)) != 2L) {
    stop("`draws` must be a 2-d matrix (iterations x chains).", call. = FALSE)
  }
  if (length(superchain_ids) != ncol(draws)) {
    stop(
      "`superchain_ids` must have one entry per chain (column of `draws`).",
      call. = FALSE
    )
  }
  if (length(superchain_ids) == 0L || anyNA(superchain_ids)) {
    stop("`superchain_ids` must be non-empty and contain no NA.", call. = FALSE)
  }

  # Use the ids actually present rather than assuming they are exactly 1..K.
  superchains <- sort(unique(superchain_ids))
  superchain_index <- match(superchain_ids, superchains)

  # determine size of dims
  num_iterations <- nrow(draws)
  num_superchains <- length(superchains)
  chains_per_superchain <- tabulate(superchain_index, nbins = num_superchains)
  num_chains_per_superchain <- chains_per_superchain[1L]

  # Unequal group sizes cannot fill a rectangular array. Without this check a
  # smaller superchain could be silently recycled to fill its slice.
  if (any(chains_per_superchain != num_chains_per_superchain)) {
    stop(
      "All superchains must contain the same number of chains.",
      call. = FALSE
    )
  }

  # order() keeps ties in their original order, so chains are grouped by
  # superchain while preserving their relative column order. Filling the
  # array column-major then places each group in its own slice.
  chain_order <- order(superchain_index)

  array(
    draws[, chain_order, drop = FALSE],
    dim = c(num_iterations, num_chains_per_superchain, num_superchains),
    dimnames = list(
      iteration = as.character(seq_len(num_iterations)),
      chain = as.character(seq_len(num_chains_per_superchain)),
      superchain = as.character(superchains)
    )
  )
}

#' Nested Rhat convergence diagnostic
#'
#' Compute the Nested Rhat convergence diagnostic for a single variable
Expand All @@ -57,37 +24,39 @@ rhat_nested <- function(x, superchain_ids, ...) UseMethod("rhat_nested")
#' @rdname rhat_nested
#' @export
rhat_nested.default <- function(x, superchain_ids, ...) {
  # Default method: hand the draws and the chain-to-superchain mapping
  # straight to the internal implementation. Arguments in `...` are accepted
  # for S3 signature compatibility but are not forwarded.
  nested_rhat <- .rhat_nested(x, superchain_ids = superchain_ids)
  nested_rhat
}

x <- .add_superchain_ids(x, superchain_ids)
.rhat_nested(x)
#' @rdname rhat_nested
#' @export
rhat_nested.rvar <- function(x, superchain_ids, ...) {
  # rvar method: apply rhat_nested to each element of the rvar through the
  # per-element chain-aware summary helper, passing `superchain_ids` and any
  # extra arguments along.
  summarise_rvar_by_element_with_chains(
    x,
    rhat_nested,
    superchain_ids = superchain_ids,
    ...
  )
}

.rhat_nested <- function(x, ...) {

.rhat_nested <- function(x, superchain_ids, ...) {

array_dims <- dim(x)
ndraws <- array_dims[1]
nchains <- array_dims[2]
nsuperchains <- array_dims[3]

superchain_mean <- apply(x, 3, mean)
chain_mean <- apply(x, c(2, 3), mean)
chain_var <- apply(x, c(2, 3), var)
nchains_per_superchain <- max(table(superchain_ids))
nsuperchains <- length(unique(superchain_ids))

superchain_mean <- sapply(unique(superchain_ids), function(k) mean(x[, which(superchain_ids == k)]))

chain_mean <- matrix(matrixStats::colMeans2(x), nrow = 1)
chain_var <- matrixStats::colVars(x, center=chain_mean)

overall_mean <- mean(superchain_mean)

if (nchains == 1) {
if (nchains_per_superchain == 1) {
var_between_chain <- 0
} else {
var_between_chain <- matrixStats::colVars(
chain_mean,
center = superchain_mean
)
var_between_chain <- sapply(unique(superchain_ids), function(k) var(chain_mean[, which(superchain_ids == k)]))
}

if (ndraws == 1) {
var_within_chain <- 0
} else {
var_within_chain <- colMeans(chain_var)
var_within_chain <- sapply(unique(superchain_ids), function(k) mean(chain_var[which(superchain_ids == k)]))
}

var_between_superchain <- matrixStats::colVars(
Expand All @@ -97,5 +66,5 @@ rhat_nested.default <- function(x, superchain_ids, ...) {

var_within_superchain <- mean(var_within_chain + var_between_chain)

sqrt(1 + var_between_superchain / var_within_superchain)
sqrt(1 + var_between_superchain / var_within_superchain)
}

0 comments on commit e02321d

Please sign in to comment.