diff --git a/NEWS.md b/NEWS.md index abc8cc1e..c427e51b 100644 --- a/NEWS.md +++ b/NEWS.md @@ -20,6 +20,8 @@ and `vctrs::vec_proxy_order()`. * Minor future-proofing of `cbind()`, `rbind()`, and `chol()` for R 4.4 (#304). +* Ensure that `bind_draws()` regenerates draw ids when binding along + chains or draws; this also fixes a bug in `split_chains()` (#300). # posterior 1.4.1 diff --git a/R/bind_draws.R b/R/bind_draws.R index 6eecac59..36b6ce1e 100644 --- a/R/bind_draws.R +++ b/R/bind_draws.R @@ -215,7 +215,11 @@ bind_draws.draws_rvars <- function(x, ..., along = "variable") { out <- lapply(seq_along(dots[[1]]), function(var_i) { vars <- lapply(dots, `[[`, var_i) var_draws <- lapply(vars, draws_of) - out <- rvar(abind(var_draws, along = 1), nchains = nchains) + new_draws <- abind(var_draws, along = 1) + # must regenerate draw ids in case binding along draws or chains generates + # duplicate or out of sequence draw ids + rownames(new_draws) <- seq_rows(new_draws) + out <- rvar(new_draws, nchains = nchains) out }) names(out) <- names(dots[[1]]) diff --git a/R/order_draws.R b/R/order_draws.R index 55773290..45e76ecf 100644 --- a/R/order_draws.R +++ b/R/order_draws.R @@ -78,6 +78,7 @@ order_draws.rvar <- function(x, ...) { # if ordering is needed, must also merge chains (as out-of-order draws # imply chain information is no longer meaningful) if (nchains(x) > 1) { + warn_merge_chains("index") x <- merge_chains(x) } draws_of(x) <- vec_slice(draws_of(x), draw_order) diff --git a/tests/testthat/test-bind_draws.R b/tests/testthat/test-bind_draws.R index 6ccaaed4..c4ec8962 100644 --- a/tests/testthat/test-bind_draws.R +++ b/tests/testthat/test-bind_draws.R @@ -23,6 +23,7 @@ test_that("bind_draws works for draws_matrix objects", { nchains(draws1) + nchains(draws2) ) expect_equal(variables(draws_new), variables(draws1)) + expect_equal(draw_ids(draws_new), seq_len(ndraws(draws_new))) draws_new <- bind_draws(draws1, draws3, along = "draw") expect_equal( @@ -51,6 +52,7 @@ test_that("bind_draws works for draws_array objects", { draws_new <- bind_draws(draws1, draws2, along = "chain") expect_equal(nchains(draws_new), nchains(draws1) + nchains(draws2)) expect_equal(variables(draws_new), variables(draws1)) + expect_equal(draw_ids(draws_new), seq_len(ndraws(draws_new))) draws_new <- bind_draws(draws1, draws3, along = "iteration") expect_equal( @@ -82,6 +84,7 @@ test_that("bind_draws works for draws_df objects", { draws_new <- bind_draws(draws1, draws2, along = "chain") expect_equal(nchains(draws_new), nchains(draws1) + nchains(draws2)) expect_equal(variables(draws_new), variables(draws1)) + expect_equal(draw_ids(draws_new), seq_len(ndraws(draws_new))) draws_new <- bind_draws(draws1, draws3, along = "iteration") expect_equal( @@ -150,6 +153,7 @@ test_that("bind_draws works for draws_list objects", { draws_new <- bind_draws(draws1, draws2, along = "chain") expect_equal(nchains(draws_new), nchains(draws1) + nchains(draws2)) expect_equal(variables(draws_new), variables(draws1)) + expect_equal(draw_ids(draws_new), seq_len(ndraws(draws_new))) draws_new <- bind_draws(draws1, draws3, along = "iteration") expect_equal( @@ -181,6 +185,7 @@ test_that("bind_draws works for draws_rvars objects", { draws_new <- bind_draws(draws1, draws2, along = "chain") expect_equal(nchains(draws_new), nchains(draws1) + nchains(draws2)) expect_equal(variables(draws_new), variables(draws1)) + expect_equal(draw_ids(draws_new), seq_len(ndraws(draws_new))) expect_error(bind_draws(draws1, draws3, along = "iteration"), "Cannot bind 'draws_rvars' objects along 'iteration'") diff --git a/tests/testthat/test-split_chains.R b/tests/testthat/test-split_chains.R new file mode 100644 index 00000000..d0b3be48 --- /dev/null +++ b/tests/testthat/test-split_chains.R @@ -0,0 +1,11 @@ +test_that("split_chains() works correctly", { + x_array <- array(1:48, dim = c(4, 3, 4)) + x_rvar <- rvar(x_array, with_chains = TRUE) + x_draws <- draws_rvars(x = x_rvar) + + x_split_array <- abind::abind(x_array[1:2,,], x_array[3:4,,], along = 2) + x_split_rvar <- rvar(x_split_array, with_chains = TRUE) + x_split_draws <- draws_rvars(x = x_split_rvar) + + expect_equal(split_chains(x_draws), x_split_draws) +})