Skip to content

Commit

Permalink
Merge pull request #306 from stan-dev/issue-300
Browse files Browse the repository at this point in the history
Regenerate draw ids for bind_draws(<draws_rvar>, along = "chain")
  • Loading branch information
paul-buerkner authored Oct 31, 2023
2 parents 7868fcf + e551e5d commit 5e0a269
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 1 deletion.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
and `vctrs::vec_proxy_order()`.
* Minor future-proofing of `cbind(<rvar>)`, `rbind(<rvar>)`, and `chol(<rvar>)`
for R 4.4 (#304).
* Ensure that `bind_draws(<draws_rvars>)` regenerates draw ids when binding along
chains or draws; this also fixes a bug in `split_chains(<draws_rvars>)` (#300).


# posterior 1.4.1
Expand Down
6 changes: 5 additions & 1 deletion R/bind_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,11 @@ bind_draws.draws_rvars <- function(x, ..., along = "variable") {
out <- lapply(seq_along(dots[[1]]), function(var_i) {
vars <- lapply(dots, `[[`, var_i)
var_draws <- lapply(vars, draws_of)
out <- rvar(abind(var_draws, along = 1), nchains = nchains)
new_draws <- abind(var_draws, along = 1)
# must regenerate draw ids in case binding along draws or chains generates
# duplicate or out of sequence draw ids
rownames(new_draws) <- seq_rows(new_draws)
out <- rvar(new_draws, nchains = nchains)
out
})
names(out) <- names(dots[[1]])
Expand Down
1 change: 1 addition & 0 deletions R/order_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ order_draws.rvar <- function(x, ...) {
# if ordering is needed, must also merge chains (as out-of-order draws
# imply chain information is no longer meaningful)
if (nchains(x) > 1) {
warn_merge_chains("index")
x <- merge_chains(x)
}
draws_of(x) <- vec_slice(draws_of(x), draw_order)
Expand Down
5 changes: 5 additions & 0 deletions tests/testthat/test-bind_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ test_that("bind_draws works for draws_matrix objects", {
nchains(draws1) + nchains(draws2)
)
expect_equal(variables(draws_new), variables(draws1))
expect_equal(draw_ids(draws_new), seq_len(ndraws(draws_new)))

draws_new <- bind_draws(draws1, draws3, along = "draw")
expect_equal(
Expand Down Expand Up @@ -51,6 +52,7 @@ test_that("bind_draws works for draws_array objects", {
draws_new <- bind_draws(draws1, draws2, along = "chain")
expect_equal(nchains(draws_new), nchains(draws1) + nchains(draws2))
expect_equal(variables(draws_new), variables(draws1))
expect_equal(draw_ids(draws_new), seq_len(ndraws(draws_new)))

draws_new <- bind_draws(draws1, draws3, along = "iteration")
expect_equal(
Expand Down Expand Up @@ -82,6 +84,7 @@ test_that("bind_draws works for draws_df objects", {
draws_new <- bind_draws(draws1, draws2, along = "chain")
expect_equal(nchains(draws_new), nchains(draws1) + nchains(draws2))
expect_equal(variables(draws_new), variables(draws1))
expect_equal(draw_ids(draws_new), seq_len(ndraws(draws_new)))

draws_new <- bind_draws(draws1, draws3, along = "iteration")
expect_equal(
Expand Down Expand Up @@ -150,6 +153,7 @@ test_that("bind_draws works for draws_list objects", {
draws_new <- bind_draws(draws1, draws2, along = "chain")
expect_equal(nchains(draws_new), nchains(draws1) + nchains(draws2))
expect_equal(variables(draws_new), variables(draws1))
expect_equal(draw_ids(draws_new), seq_len(ndraws(draws_new)))

draws_new <- bind_draws(draws1, draws3, along = "iteration")
expect_equal(
Expand Down Expand Up @@ -181,6 +185,7 @@ test_that("bind_draws works for draws_rvars objects", {
draws_new <- bind_draws(draws1, draws2, along = "chain")
expect_equal(nchains(draws_new), nchains(draws1) + nchains(draws2))
expect_equal(variables(draws_new), variables(draws1))
expect_equal(draw_ids(draws_new), seq_len(ndraws(draws_new)))

expect_error(bind_draws(draws1, draws3, along = "iteration"),
"Cannot bind 'draws_rvars' objects along 'iteration'")
Expand Down
11 changes: 11 additions & 0 deletions tests/testthat/test-split_chains.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
test_that("split_chains(<draws_rvar>) works correctly", {
x_array <- array(1:48, dim = c(4, 3, 4))
x_rvar <- rvar(x_array, with_chains = TRUE)
x_draws <- draws_rvars(x = x_rvar)

x_split_array <- abind::abind(x_array[1:2,,], x_array[3:4,,], along = 2)
x_split_rvar <- rvar(x_split_array, with_chains = TRUE)
x_split_draws <- draws_rvars(x = x_split_rvar)

expect_equal(split_chains(x_draws), x_split_draws)
})

0 comments on commit 5e0a269

Please sign in to comment.