From a6b8cd45453cda7ef2830a459c059fb7d4b1f171 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Mon, 30 Oct 2023 12:30:32 +1100 Subject: [PATCH 1/5] ensure test for bind_draws(along = chain) tests for correct draw ids --- tests/testthat/test-bind_draws.R | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/testthat/test-bind_draws.R b/tests/testthat/test-bind_draws.R index 6ccaaed4..c4ec8962 100644 --- a/tests/testthat/test-bind_draws.R +++ b/tests/testthat/test-bind_draws.R @@ -23,6 +23,7 @@ test_that("bind_draws works for draws_matrix objects", { nchains(draws1) + nchains(draws2) ) expect_equal(variables(draws_new), variables(draws1)) + expect_equal(draw_ids(draws_new), seq_len(ndraws(draws_new))) draws_new <- bind_draws(draws1, draws3, along = "draw") expect_equal( @@ -51,6 +52,7 @@ test_that("bind_draws works for draws_array objects", { draws_new <- bind_draws(draws1, draws2, along = "chain") expect_equal(nchains(draws_new), nchains(draws1) + nchains(draws2)) expect_equal(variables(draws_new), variables(draws1)) + expect_equal(draw_ids(draws_new), seq_len(ndraws(draws_new))) draws_new <- bind_draws(draws1, draws3, along = "iteration") expect_equal( @@ -82,6 +84,7 @@ test_that("bind_draws works for draws_df objects", { draws_new <- bind_draws(draws1, draws2, along = "chain") expect_equal(nchains(draws_new), nchains(draws1) + nchains(draws2)) expect_equal(variables(draws_new), variables(draws1)) + expect_equal(draw_ids(draws_new), seq_len(ndraws(draws_new))) draws_new <- bind_draws(draws1, draws3, along = "iteration") expect_equal( @@ -150,6 +153,7 @@ test_that("bind_draws works for draws_list objects", { draws_new <- bind_draws(draws1, draws2, along = "chain") expect_equal(nchains(draws_new), nchains(draws1) + nchains(draws2)) expect_equal(variables(draws_new), variables(draws1)) + expect_equal(draw_ids(draws_new), seq_len(ndraws(draws_new))) draws_new <- bind_draws(draws1, draws3, along = "iteration") expect_equal( @@ -181,6 +185,7 @@ test_that("bind_draws works for draws_rvars objects", { draws_new <- bind_draws(draws1, draws2, along = "chain") expect_equal(nchains(draws_new), nchains(draws1) + nchains(draws2)) expect_equal(variables(draws_new), variables(draws1)) + expect_equal(draw_ids(draws_new), seq_len(ndraws(draws_new))) expect_error(bind_draws(draws1, draws3, along = "iteration"), "Cannot bind 'draws_rvars' objects along 'iteration'") From c1ed42e2b42e809d76eeb81cdd53a35de6323388 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Mon, 30 Oct 2023 12:30:55 +1100 Subject: [PATCH 2/5] add test for split_chains() --- tests/testthat/test-split_chains.R | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 tests/testthat/test-split_chains.R diff --git a/tests/testthat/test-split_chains.R b/tests/testthat/test-split_chains.R new file mode 100644 index 00000000..d0b3be48 --- /dev/null +++ b/tests/testthat/test-split_chains.R @@ -0,0 +1,11 @@ +test_that("split_chains() works correctly", { + x_array <- array(1:48, dim = c(4, 3, 4)) + x_rvar <- rvar(x_array, with_chains = TRUE) + x_draws <- draws_rvars(x = x_rvar) + + x_split_array <- abind::abind(x_array[1:2,,], x_array[3:4,,], along = 2) + x_split_rvar <- rvar(x_split_array, with_chains = TRUE) + x_split_draws <- draws_rvars(x = x_split_rvar) + + expect_equal(split_chains(x_draws), x_split_draws) +}) From 04433255914de458472a99e34088fd5eb6507b55 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Mon, 30 Oct 2023 12:31:45 +1100 Subject: [PATCH 3/5] order_draws() should warn if it must merge chains --- R/order_draws.R | 1 + 1 file changed, 1 insertion(+) diff --git a/R/order_draws.R b/R/order_draws.R index 55773290..45e76ecf 100644 --- a/R/order_draws.R +++ b/R/order_draws.R @@ -78,6 +78,7 @@ order_draws.rvar <- function(x, ...) { # if ordering is needed, must also merge chains (as out-of-order draws # imply chain information is no longer meaningful) if (nchains(x) > 1) { + warn_merge_chains("index") x <- merge_chains(x) } draws_of(x) <- vec_slice(draws_of(x), draw_order) From f55034be0bfbbe4b94159e25b924b4a95fb59530 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Mon, 30 Oct 2023 12:32:17 +1100 Subject: [PATCH 4/5] todo for fixing #300 --- R/bind_draws.R | 1 + 1 file changed, 1 insertion(+) diff --git a/R/bind_draws.R b/R/bind_draws.R index 6eecac59..6611c2b7 100644 --- a/R/bind_draws.R +++ b/R/bind_draws.R @@ -202,6 +202,7 @@ bind_draws.draws_rvars <- function(x, ..., along = "variable") { } else if (along == "iteration") { stop_no_call("Cannot bind 'draws_rvars' objects along 'iteration'.") } else if (along %in% c("chain", "draw")) { + # TODO here: make sure both "draw" and "chain" result in new draw_ids (add tests) check_same_fun_output(dots, variables) if (along == "chain") { check_same_fun_output(dots, iteration_ids) From e551e5daead9f9fb1dd8d339915e02a40cf39e60 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Mon, 30 Oct 2023 21:08:33 -0500 Subject: [PATCH 5/5] ensure bind_draws(along = "chain") for rvars regenerates draw ids, closing #300 --- NEWS.md | 2 ++ R/bind_draws.R | 7 +++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/NEWS.md b/NEWS.md index abc8cc1e..c427e51b 100644 --- a/NEWS.md +++ b/NEWS.md @@ -20,6 +20,8 @@ and `vctrs::vec_proxy_order()`. * Minor future-proofing of `cbind()`, `rbind()`, and `chol()` for R 4.4 (#304). +* Ensure that `bind_draws()` regenerates draw ids when binding along + chains or draws; this also fixes a bug in `split_chains()` (#300). # posterior 1.4.1 diff --git a/R/bind_draws.R b/R/bind_draws.R index 6611c2b7..36b6ce1e 100644 --- a/R/bind_draws.R +++ b/R/bind_draws.R @@ -202,7 +202,6 @@ bind_draws.draws_rvars <- function(x, ..., along = "variable") { } else if (along == "iteration") { stop_no_call("Cannot bind 'draws_rvars' objects along 'iteration'.") } else if (along %in% c("chain", "draw")) { - # TODO here: make sure both "draw" and "chain" result in new draw_ids (add tests) check_same_fun_output(dots, variables) if (along == "chain") { check_same_fun_output(dots, iteration_ids) @@ -216,7 +215,11 @@ bind_draws.draws_rvars <- function(x, ..., along = "variable") { out <- lapply(seq_along(dots[[1]]), function(var_i) { vars <- lapply(dots, `[[`, var_i) var_draws <- lapply(vars, draws_of) - out <- rvar(abind(var_draws, along = 1), nchains = nchains) + new_draws <- abind(var_draws, along = 1) + # must regenerate draw ids in case binding along draws or chains generates + # duplicate or out of sequence draw ids + rownames(new_draws) <- seq_rows(new_draws) + out <- rvar(new_draws, nchains = nchains) out }) names(out) <- names(dots[[1]])