Skip to content

Commit

Permalink
add input checks and tests for rhat_nested
Browse files Browse the repository at this point in the history
  • Loading branch information
n-kall committed Oct 25, 2023
1 parent 9e5a44f commit 32cd38b
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
24 changes: 23 additions & 1 deletion R/nested_rhat.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,34 @@ rhat_nested.rvar <- function(x, superchain_ids, ...) {
}

.rhat_nested <- function(x, superchain_ids, ...) {
if (should_return_NA(x)) {
return(NA_real_)
}

x <- as.matrix(x)
niterations <- NROW(x)
nchains_per_superchain <- max(table(superchain_ids))
nchains <- NCOL(x)


# check that all chains are assigned a superchain
if (length(superchain_ids) != nchains) {
warning_no_call("Length of superchain_ids not equal to number of chains, returning NA.")
return(NA_real_)
}


# check that superchains are equal length
superchain_id_table <- table(superchain_ids)
nchains_per_superchain <- max(superchain_id_table)

if (nchains_per_superchain != min(superchain_id_table)) {
warning_no_call("Number of chains per superchain is not the same for each superchain, returning NA.")
return(NA_real_)
}

superchains <- unique(superchain_ids)


# mean and variance of chains calculated as in rhat
chain_mean <- matrixStats::colMeans2(x)
chain_var <- matrixStats::colVars(x, center = chain_mean)
Expand Down
20 changes: 20 additions & 0 deletions tests/testthat/test-rhat_nested.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,23 @@ test_that("rhat_nested returns reasonable values", {
})


test_that("rhat_nested handles special cases correctly", {
set.seed(1234)
x <- c(rnorm(10), NA)
expect_true(is.na(rhat_nested(x, superchain_ids = c(1))))

x <- c(rnorm(10), Inf)
expect_true(is.na(rhat_nested(x, superchain_ids = c(1,2,1,2))))

tau <- extract_variable_matrix(example_draws(), "tau")
expect_warning(
rhat_nested(tau, superchain_ids = c(1,1,1,3)),
"Number of chains per superchain is not the same for each superchain, returning NA."
)

tau <- extract_variable_matrix(example_draws(), "tau")
expect_warning(
rhat_nested(tau, superchain_ids = c(1,2)),
"Length of superchain_ids not equal to number of chains, returning NA."
)
})

0 comments on commit 32cd38b

Please sign in to comment.