diff --git a/DESCRIPTION b/DESCRIPTION index a4a4655e..780a1d58 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: posterior Title: Tools for Working with Posterior Distributions -Version: 1.3.1 +Version: 1.3.1.9000 Date: 2022-09-06 Authors@R: c(person("Paul-Christian", "Bürkner", email = "paul.buerkner@gmail.com", role = c("aut", "cre")), person("Jonah", "Gabry", email = "jsg2201@columbia.edu", role = c("aut")), diff --git a/NAMESPACE b/NAMESPACE index 709acd43..b0d692df 100755 --- a/NAMESPACE +++ b/NAMESPACE @@ -90,6 +90,7 @@ S3method(bind_draws,draws_df) S3method(bind_draws,draws_list) S3method(bind_draws,draws_matrix) S3method(bind_draws,draws_rvars) +S3method(bind_draws,list) S3method(c,rvar) S3method(cbind,rvar) S3method(cdf,rvar) diff --git a/NEWS.md b/NEWS.md index 37e36c98..158f780b 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,7 +1,16 @@ +# posterior 1.3.1.9000 + +### Enhancements + +* Allow lists of draws objects to be passed as the first argument to + `bind_draws()` (#253). + + # posterior 1.3.1 * Minor release that fixes some CRAN check failures. + # posterior 1.3.0 ### Enhancements diff --git a/R/bind_draws.R b/R/bind_draws.R index 14c3ca97..6eecac59 100644 --- a/R/bind_draws.R +++ b/R/bind_draws.R @@ -230,7 +230,12 @@ bind_draws.NULL <- function(x, ..., along = "variable") { if (!length(dots)) { stop_no_call("All objects passed to 'bind_draws' are NULL.") } - do.call(bind_draws, dots) + do.call("bind_draws", dots) +} + +#' @export +bind_draws.list <- function(x, ..., along = "variable") { + do.call("bind_draws", c(x, ..., along = along)) } # check if function output is the same across objects diff --git a/tests/testthat/test-bind_draws.R b/tests/testthat/test-bind_draws.R index 6c240db2..6ccaaed4 100644 --- a/tests/testthat/test-bind_draws.R +++ b/tests/testthat/test-bind_draws.R @@ -193,6 +193,21 @@ test_that("bind_draws works for draws_rvars objects", { expect_equal(draws_new, draws1) }) +test_that("bind_draws works for list objects", { + draws1 <- as_draws_df(example_draws()) + draws2 <- subset_draws(draws1, chain = 2) + draws3 <- subset_draws(draws1, chain = 3) + + draws12 <- bind_draws(draws1, draws2, along = "chain") + draws_all <- bind_draws(draws1, draws2, draws3, along = "chain") + expect_equal(bind_draws(list(draws1, draws2), along = "chain"), draws12) + expect_equal(bind_draws(list(draws1, draws2, draws3), along = "chain"), draws_all) + + draws4 <- subset_draws(draws1, chain = 4) + draws_all <- bind_draws(draws2, draws3, draws4, along = "iteration") + expect_equal(bind_draws(list(draws2, draws3, draws4), along = "iteration"), draws_all) +}) + test_that("bind_draws errors if all NULL", { expect_error(bind_draws(NULL, NULL), "All objects passed to 'bind_draws' are NULL") })