Skip to content

Commit

Permalink
test coverage improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
mjskay committed Jan 18, 2024
1 parent 316a81e commit 534d7fb
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 18 deletions.
24 changes: 8 additions & 16 deletions R/rvar-.R
Original file line number Diff line number Diff line change
Expand Up @@ -716,24 +716,16 @@ conform_rvar_nchains_ndraws_weights <- function(rvars, promote_unweighted = TRUE
rvars
}

# Check that the first rvar can be conformed to the dimensions of the second,
# ignoring 1s
check_rvar_dims_first <- function(x, y) {
x_dim <- dim(x)
x_dim_dropped <- as.integer(x_dim[x_dim != 1])
y_dim <- dim(y)
y_dim_dropped <- as.integer(y_dim[y_dim != 1])

if (length(x_dim_dropped) == 0) {
# x can be treated as scalar, do so
dim(x) <- rep(1, length(dim(y)))
} else if (identical(x_dim_dropped, y_dim_dropped)) {
dim(x) <- dim(y)
} else {
stop_no_call("Cannot assign an rvar with dimension ", paste0(x_dim, collapse = ","),
" to an rvar with dimension ", paste0(y_dim, collapse = ","))
#' Check that an rvar is a scalar (length 1)
#' @param x rvar to check
#' @returns x with `dim(x) == 1`, or throws an error if `x` is not scalar.
#' @noRd
check_rvar_is_scalar <- function(x) {
if (length(x) != 1) {
stop_no_call("Cannot insert an rvar with length != 1 into another rvar using `[[`")
}

dim(x) <- 1
x
}

Expand Down
4 changes: 3 additions & 1 deletion R/rvar-cast.R
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,9 @@ vec_restore.rvar <- function(x, ...) {

# find runs where the same underlying draws are in the proxy
different_draws_from_previous <- vapply(seq_along(x)[-1], FUN.VALUE = logical(1), function(i) {
!identical(x[[i]]$draws, x[[i - 1]]$draws) || !identical(x[[i]]$nchains, x[[i - 1]]$nchains)
!identical(x[[i]]$draws, x[[i - 1]]$draws) ||
!identical(x[[i]]$nchains, x[[i - 1]]$nchains) ||
!identical(x[[i]]$log_weights, x[[i - 1]]$log_weights)
})
draws_groups <- cumsum(c(TRUE, different_draws_from_previous))

Expand Down
2 changes: 1 addition & 1 deletion R/rvar-slice.R
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ NULL
`[[<-.rvar` <- function(x, i, ..., value) {
value <- vec_cast(value, x)
c(x, value) %<-% conform_rvar_nchains_ndraws_weights(list(x, value))
value <- check_rvar_dims_first(value, new_rvar(0))
value <- check_rvar_is_scalar(value)
index <- check_rvar_yank_index(x, i, ...)

if (length(index) == 1) {
Expand Down
12 changes: 12 additions & 0 deletions tests/testthat/test-rvar-cast.R
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,18 @@ test_that("casting to/from rvar/distribution objects works", {
expect_error(vctrs::vec_cast(x_mv, null_dist))
})

test_that("vec_c works with rvar and distributions", {
x_dist <- distributional::dist_sample(list(a = 1:2, b = 3:4))
y_dist <- distributional::dist_sample(list(c = 5:6, d = 7:8))
xy_dist <- distributional::dist_sample(list(a = 1:2, b = 3:4, c = 5:6, d = 7:8))
x_rvar <- rvar(matrix(c(1:4), ncol = 2, dimnames = list(NULL, c("a","b"))))
y_rvar <- rvar(matrix(c(5:8), ncol = 2, dimnames = list(NULL, c("c","d"))))
xy_rvar <- rvar(matrix(c(1:8), ncol = 4, dimnames = list(NULL, c("a","b","c","d"))))

expect_equal(vctrs::vec_c(x_dist, y_rvar), xy_dist)
expect_equal(vctrs::vec_c(x_rvar, y_dist), xy_rvar)
})


# type predicates ---------------------------------------------------------

Expand Down
39 changes: 39 additions & 0 deletions tests/testthat/test-rvar-print.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,14 @@ test_that("print(<rvar_factor>) works", {
regexp = "12 levels: a b c d e f g h i j k l",
all = FALSE
)

x_long <- rvar_factor(combn(letters, 2, paste, collapse = ""))
out <- capture.output(print(x_long, color = FALSE, width = 50))
expect_match(
out,
regexp = "325 levels: ab ac ad ae af ag ah ai aj \\.\\.\\. yz",
all = FALSE
)
})

test_that("print(<rvar_ordered>) works", {
Expand Down Expand Up @@ -255,9 +263,40 @@ test_that("str(<rvar_ordered>) works", {
)
})

test_that("str(<weighted rvar>) works", {
x <- rvar(1:100, log_weights = 2:101)

expect_output(str(weight_draws(rvar(), 1)),
" weighted rvar<1>\\[0\\] "
)
out <- capture.output(str(x))
expect_match(
out,
regexp = " weighted rvar<100>\\[1\\] 99 . 0.96",
all = FALSE
)
expect_match(
out,
regexp = " - log_weights\\(\\*\\)= int \\[1:100\\] 2 3 4 5",
all = FALSE
)
})


# other -------------------------------------------------------------------

test_that("tibble printing works", {
skip_on_cran()

x <- rvar(1:10)
out <- capture.output(print(tibble::tibble(x)))
expect_match(
out,
regexp = " 5.5 . 3",
all = FALSE
)
})

test_that("glimpse on rvar works", {
skip_on_cran()
x_vec <- rvar(array(1:24, dim = c(6,4)))
Expand Down
45 changes: 45 additions & 0 deletions tests/testthat/test-weight_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ test_that("weight_draws works on draws_matrix", {
x2 <- weight_draws(x, log(weights), log = TRUE)
weights2 <- weights(x2)
expect_equal(weights2, weights / sum(weights))

# test replacement of weights
expect_equal(weight_draws(x1, weights2), weight_draws(x, weights2))
})

test_that("weight_draws works on draws_array", {
Expand All @@ -22,6 +25,9 @@ test_that("weight_draws works on draws_array", {
x2 <- weight_draws(x, log(weights), log = TRUE)
weights2 <- weights(x2, normalize = FALSE)
expect_equal(weights2, weights)

# test replacement of weights
expect_equal(weight_draws(x1, weights2), weight_draws(x, weights2))
})

test_that("weight_draws works on draws_df", {
Expand All @@ -35,6 +41,9 @@ test_that("weight_draws works on draws_df", {
x2 <- weight_draws(x, log(weights), log = TRUE)
weights2 <- weights(x2)
expect_equal(weights2, weights / sum(weights))

# test replacement of weights
expect_equal(weight_draws(x1, weights2), weight_draws(x, weights2))
})

test_that("weight_draws works on draws_list", {
Expand All @@ -48,6 +57,9 @@ test_that("weight_draws works on draws_list", {
x2 <- weight_draws(x, log(weights), log = TRUE)
weights2 <- weights(x2, normalize = FALSE)
expect_equal(weights2, weights)

# test replacement of weights
expect_equal(weight_draws(x1, weights2), weight_draws(x, weights2))
})

test_that("weight_draws works on draws_rvars", {
Expand All @@ -61,6 +73,9 @@ test_that("weight_draws works on draws_rvars", {
x2 <- weight_draws(x, log(weights), log = TRUE)
weights2 <- weights(x2, normalize = FALSE)
expect_equal(weights2, weights)

# test replacement of weights
expect_equal(weight_draws(x1, weights2), weight_draws(x, weights2))
})

test_that("weights are propagated to variables in draws_rvars", {
Expand All @@ -83,6 +98,26 @@ test_that("weights are propagated to variables in draws_rvars", {
)
})

# removing weights works --------------------------------------------------

test_that("weights can be removed", {
x <- list(
matrix = as_draws_matrix(example_draws()),
array = as_draws_array(example_draws()),
df = as_draws_df(example_draws()),
list = as_draws_list(example_draws()),
rvars = as_draws_rvars(example_draws()),
rvar = as_draws_rvars(example_draws())$mu
)

weights <- rexp(ndraws(example_draws()))
x_weighted <- lapply(x, weight_draws, weights)

for (type in names(x)) {
expect_equal(weight_draws(x_weighted[[!!type]], NULL), x[[!!type]])
}
})

# conversion preserves weights --------------------------------------------

test_that("conversion between formats preserves weights", {
Expand Down Expand Up @@ -118,3 +153,13 @@ test_that("pareto smoothing smooths weights in weight_draws", {
smoothed <- weight_draws(x, lw, pareto_smooth = TRUE, log = TRUE)
expect_false(all(weights(weighted) == weights(smoothed)))
})

# weights must match draws ------------------------------------------------

test_that("weights must match draws", {
x <- example_draws()
types <- list(as_draws_matrix, as_draws_array, as_draws_df, as_draws_list, as_draws_rvars)
for (type in types) {
expect_error(weight_draws((!!type)(x), 1), "weights must match .* draws")
}
})

0 comments on commit 534d7fb

Please sign in to comment.