Skip to content

Commit

Permalink
Merge pull request #387 from stan-dev/convert-list-of-matrices-to-dra…
Browse files Browse the repository at this point in the history
…ws_array

 Convert lists of matrices to `draws_array` objects
  • Loading branch information
jgabry authored Dec 17, 2024
2 parents a6ab390 + bf68d98 commit aeb26ba
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 18 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: posterior
Title: Tools for Working with Posterior Distributions
Version: 1.6.0
Date: 2024-06-28
Version: 1.6.0.9000
Date: 2024-12-17
Authors@R: c(person("Paul-Christian", "Bürkner", email = "[email protected]", role = c("aut", "cre")),
person("Jonah", "Gabry", email = "[email protected]", role = c("aut")),
person("Matthew", "Kay", email = "[email protected]", role = c("aut")),
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# posterior 1.6.0+

### Enhancements

* Convert lists of matrices to `draws_array` objects.

# posterior 1.6.0

### Enhancements
Expand Down
9 changes: 5 additions & 4 deletions R/as_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,11 @@ closest_draws_format <- function(x) {
out <- "rvars"
} else if (is_draws_list_like(x)) {
out <- "list"
}
else {
stop_no_call("Don't know how to transform an object of class ",
"'", class(x)[1L], "' to any supported draws format.")
} else {
stop_no_call(
"Don't know how to transform an object of class '",
class(x)[1L], "' to any supported draws format."
)
}
paste0("draws_", out)
}
Expand Down
27 changes: 15 additions & 12 deletions R/as_draws_array.R
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,11 @@ as_draws_array.mcmc.list <- function(x, ...) {

# try to convert any R object into a 'draws_array' object
.as_draws_array <- function(x) {
x <- as.array(x)
if (is_matrix_list_like(x)) {
x <- as_array_matrix_list(x)
} else {
x <- as.array(x)
}
new_dimnames <- list(iteration = NULL, chain = NULL, variable = NULL)
if (!is.null(dimnames(x)[[3]])) {
new_dimnames[[3]] <- dimnames(x)[[3]]
Expand Down Expand Up @@ -177,7 +181,14 @@ is_draws_array <- function(x) {

# is an object looking like a 'draws_array' object?
is_draws_array_like <- function(x) {
is.array(x) && length(dim(x)) == 3L
is.array(x) && length(dim(x)) == 3L ||
is_matrix_list_like(x)
}

# is an object likely a list of matrices?
# such an object can be easily converted to a draws_array
is_matrix_list_like <- function(x) {
is.list(x) && length(dim(x[[1]])) == 2L
}

#' Extract parts of a `draws_array` object
Expand Down Expand Up @@ -216,15 +227,8 @@ variance.draws_array <- function(x, ...) {
# convert a list of matrices to an array
as_array_matrix_list <- function(x) {
stopifnot(is.list(x))
if (length(x) == 1) {
tmp <- dimnames(x[[1]])
x <- x[[1]]
dim(x) <- c(dim(x), 1)
dimnames(x) <- tmp
} else {
x <- abind::abind(x, along = 3L)
}
x <- aperm(x, c(1, 3, 2))
x <- abind::abind(x, along = 3L)
aperm(x, c(1, 3, 2))
}

# create an empty draws_array object
Expand All @@ -245,4 +249,3 @@ empty_draws_array <- function(variables = character(0), nchains = 0,
class(out) <- class_draws_array()
out
}

22 changes: 22 additions & 0 deletions tests/testthat/test-as_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,28 @@ test_that("arrays can be transformed to draws_array objects", {
expect_equal(nchains(y), 4)
})

test_that("lists of matrices can be transformed to draws_array objects", {
x <- round(rnorm(200), 2)
x <- matrix(x, nrow = 50)
colnames(x) <- paste0("theta", 1:4)

# one chain
z1 <- list(x)
y <- as_draws(z1)
expect_is(y, "draws_array")
expect_equal(variables(y), colnames(z1[[1]]))
expect_equal(niterations(y), 50)
expect_equal(nchains(y), 1)

# multiple chains
z3 <- list(x, x, x)
y <- as_draws(z3)
expect_is(y, "draws_array")
expect_equal(variables(y), colnames(z3[[1]]))
expect_equal(niterations(y), 50)
expect_equal(nchains(y), 3)
})

test_that("data.frames can be transformed to draws_df objects", {
x <- data.frame(
v1 = rnorm(100),
Expand Down

0 comments on commit aeb26ba

Please sign in to comment.