From 534d7fb4d8cc10dff2e17ca9a582043aa7c6b022 Mon Sep 17 00:00:00 2001 From: Matthew Kay Date: Thu, 18 Jan 2024 13:45:56 -0600 Subject: [PATCH] test coverage improvements --- R/rvar-.R | 24 ++++++---------- R/rvar-cast.R | 4 ++- R/rvar-slice.R | 2 +- tests/testthat/test-rvar-cast.R | 12 ++++++++ tests/testthat/test-rvar-print.R | 39 ++++++++++++++++++++++++++ tests/testthat/test-weight_draws.R | 45 ++++++++++++++++++++++++++++++ 6 files changed, 108 insertions(+), 18 deletions(-) diff --git a/R/rvar-.R b/R/rvar-.R index 0a16d338..444d6b76 100755 --- a/R/rvar-.R +++ b/R/rvar-.R @@ -716,24 +716,16 @@ conform_rvar_nchains_ndraws_weights <- function(rvars, promote_unweighted = TRUE rvars } -# Check that the first rvar can be conformed to the dimensions of the second, -# ignoring 1s -check_rvar_dims_first <- function(x, y) { - x_dim <- dim(x) - x_dim_dropped <- as.integer(x_dim[x_dim != 1]) - y_dim <- dim(y) - y_dim_dropped <- as.integer(y_dim[y_dim != 1]) - - if (length(x_dim_dropped) == 0) { - # x can be treated as scalar, do so - dim(x) <- rep(1, length(dim(y))) - } else if (identical(x_dim_dropped, y_dim_dropped)) { - dim(x) <- dim(y) - } else { - stop_no_call("Cannot assign an rvar with dimension ", paste0(x_dim, collapse = ","), - " to an rvar with dimension ", paste0(y_dim, collapse = ",")) +#' Check that an rvar is a scalar (length 1) +#' @param x rvar to check +#' @returns x with `dim(x) == 1`, or throws an error if `x` is not scalar. +#' @noRd +check_rvar_is_scalar <- function(x) { + if (length(x) != 1) { + stop_no_call("Cannot insert an rvar with length != 1 into another rvar using `[[`") } + dim(x) <- 1 x } diff --git a/R/rvar-cast.R b/R/rvar-cast.R index 43d61972..8a123d2e 100755 --- a/R/rvar-cast.R +++ b/R/rvar-cast.R @@ -245,7 +245,9 @@ vec_restore.rvar <- function(x, ...) { # find runs where the same underlying draws are in the proxy different_draws_from_previous <- vapply(seq_along(x)[-1], FUN.VALUE = logical(1), function(i) { - !identical(x[[i]]$draws, x[[i - 1]]$draws) || !identical(x[[i]]$nchains, x[[i - 1]]$nchains) + !identical(x[[i]]$draws, x[[i - 1]]$draws) || + !identical(x[[i]]$nchains, x[[i - 1]]$nchains) || + !identical(x[[i]]$log_weights, x[[i - 1]]$log_weights) }) draws_groups <- cumsum(c(TRUE, different_draws_from_previous)) diff --git a/R/rvar-slice.R b/R/rvar-slice.R index fcb8bd99..4ea47bf6 100755 --- a/R/rvar-slice.R +++ b/R/rvar-slice.R @@ -160,7 +160,7 @@ NULL `[[<-.rvar` <- function(x, i, ..., value) { value <- vec_cast(value, x) c(x, value) %<-% conform_rvar_nchains_ndraws_weights(list(x, value)) - value <- check_rvar_dims_first(value, new_rvar(0)) + value <- check_rvar_is_scalar(value) index <- check_rvar_yank_index(x, i, ...) if (length(index) == 1) { diff --git a/tests/testthat/test-rvar-cast.R b/tests/testthat/test-rvar-cast.R index f411c56a..701f2957 100755 --- a/tests/testthat/test-rvar-cast.R +++ b/tests/testthat/test-rvar-cast.R @@ -202,6 +202,18 @@ test_that("casting to/from rvar/distribution objects works", { expect_error(vctrs::vec_cast(x_mv, null_dist)) }) +test_that("vec_c works with rvar and distributions", { + x_dist <- distributional::dist_sample(list(a = 1:2, b = 3:4)) + y_dist <- distributional::dist_sample(list(c = 5:6, d = 7:8)) + xy_dist <- distributional::dist_sample(list(a = 1:2, b = 3:4, c = 5:6, d = 7:8)) + x_rvar <- rvar(matrix(c(1:4), ncol = 2, dimnames = list(NULL, c("a","b")))) + y_rvar <- rvar(matrix(c(5:8), ncol = 2, dimnames = list(NULL, c("c","d")))) + xy_rvar <- rvar(matrix(c(1:8), ncol = 4, dimnames = list(NULL, c("a","b","c","d")))) + + expect_equal(vctrs::vec_c(x_dist, y_rvar), xy_dist) + expect_equal(vctrs::vec_c(x_rvar, y_dist), xy_rvar) +}) + # type predicates --------------------------------------------------------- diff --git a/tests/testthat/test-rvar-print.R b/tests/testthat/test-rvar-print.R index d2a8c492..63bad5ba 100755 --- a/tests/testthat/test-rvar-print.R +++ b/tests/testthat/test-rvar-print.R @@ -80,6 +80,14 @@ test_that("print() works", { regexp = "12 levels: a b c d e f g h i j k l", all = FALSE ) + + x_long <- rvar_factor(combn(letters, 2, paste, collapse = "")) + out <- capture.output(print(x_long, color = FALSE, width = 50)) + expect_match( + out, + regexp = "325 levels: ab ac ad ae af ag ah ai aj \\.\\.\\. yz", + all = FALSE + ) }) test_that("print() works", { @@ -255,9 +263,40 @@ test_that("str() works", { ) }) +test_that("str() works", { + x <- rvar(1:100, log_weights = 2:101) + + expect_output(str(weight_draws(rvar(), 1)), + " weighted rvar<1>\\[0\\] " + ) + out <- capture.output(str(x)) + expect_match( + out, + regexp = " weighted rvar<100>\\[1\\] 99 . 0.96", + all = FALSE + ) + expect_match( + out, + regexp = " - log_weights\\(\\*\\)= int \\[1:100\\] 2 3 4 5", + all = FALSE + ) +}) + # other ------------------------------------------------------------------- +test_that("tibble printing works", { + skip_on_cran() + + x <- rvar(1:10) + out <- capture.output(print(tibble::tibble(x))) + expect_match( + out, + regexp = " 5.5 . 3", + all = FALSE + ) +}) + test_that("glimpse on rvar works", { skip_on_cran() x_vec <- rvar(array(1:24, dim = c(6,4))) diff --git a/tests/testthat/test-weight_draws.R b/tests/testthat/test-weight_draws.R index 0230a83f..1602396c 100644 --- a/tests/testthat/test-weight_draws.R +++ b/tests/testthat/test-weight_draws.R @@ -9,6 +9,9 @@ test_that("weight_draws works on draws_matrix", { x2 <- weight_draws(x, log(weights), log = TRUE) weights2 <- weights(x2) expect_equal(weights2, weights / sum(weights)) + + # test replacement of weights + expect_equal(weight_draws(x1, weights2), weight_draws(x, weights2)) }) test_that("weight_draws works on draws_array", { @@ -22,6 +25,9 @@ test_that("weight_draws works on draws_array", { x2 <- weight_draws(x, log(weights), log = TRUE) weights2 <- weights(x2, normalize = FALSE) expect_equal(weights2, weights) + + # test replacement of weights + expect_equal(weight_draws(x1, weights2), weight_draws(x, weights2)) }) test_that("weight_draws works on draws_df", { @@ -35,6 +41,9 @@ test_that("weight_draws works on draws_df", { x2 <- weight_draws(x, log(weights), log = TRUE) weights2 <- weights(x2) expect_equal(weights2, weights / sum(weights)) + + # test replacement of weights + expect_equal(weight_draws(x1, weights2), weight_draws(x, weights2)) }) test_that("weight_draws works on draws_list", { @@ -48,6 +57,9 @@ test_that("weight_draws works on draws_list", { x2 <- weight_draws(x, log(weights), log = TRUE) weights2 <- weights(x2, normalize = FALSE) expect_equal(weights2, weights) + + # test replacement of weights + expect_equal(weight_draws(x1, weights2), weight_draws(x, weights2)) }) test_that("weight_draws works on draws_rvars", { @@ -61,6 +73,9 @@ test_that("weight_draws works on draws_rvars", { x2 <- weight_draws(x, log(weights), log = TRUE) weights2 <- weights(x2, normalize = FALSE) expect_equal(weights2, weights) + + # test replacement of weights + expect_equal(weight_draws(x1, weights2), weight_draws(x, weights2)) }) test_that("weights are propagated to variables in draws_rvars", { @@ -83,6 +98,26 @@ test_that("weights are propagated to variables in draws_rvars", { ) }) +# removing weights works -------------------------------------------------- + +test_that("weights can be removed", { + x <- list( + matrix = as_draws_matrix(example_draws()), + array = as_draws_array(example_draws()), + df = as_draws_df(example_draws()), + list = as_draws_list(example_draws()), + rvars = as_draws_rvars(example_draws()), + rvar = as_draws_rvars(example_draws())$mu + ) + + weights <- rexp(ndraws(example_draws())) + x_weighted <- lapply(x, weight_draws, weights) + + for (type in names(x)) { + expect_equal(weight_draws(x_weighted[[!!type]], NULL), x[[!!type]]) + } +}) + # conversion preserves weights -------------------------------------------- test_that("conversion between formats preserves weights", { @@ -118,3 +153,13 @@ test_that("pareto smoothing smooths weights in weight_draws", { smoothed <- weight_draws(x, lw, pareto_smooth = TRUE, log = TRUE) expect_false(all(weights(weighted) == weights(smoothed))) }) + +# weights must match draws ------------------------------------------------ + +test_that("weights must match draws", { + x <- example_draws() + types <- list(as_draws_matrix, as_draws_array, as_draws_df, as_draws_list, as_draws_rvars) + for (type in types) { + expect_error(weight_draws((!!type)(x), 1), "weights must match .* draws") + } +})