diff --git a/R/draws-index.R b/R/draws-index.R index 8cdfa990..5f1109eb 100644 --- a/R/draws-index.R +++ b/R/draws-index.R @@ -478,15 +478,18 @@ ndraws.rvar <- function(x) { # @param regex should 'variables' be treated as regular expressions? # @param scalar_only should only scalar variables be matched? check_existing_variables <- function(variables, x, regex = FALSE, - scalar_only = FALSE) { + scalar_only = FALSE, exclude = FALSE) { check_draws_object(x) if (is.null(variables)) { return(NULL) } + regex <- as_one_logical(regex) scalar_only <- as_one_logical(scalar_only) + exclude <- as_one_logical(exclude) variables <- unique(as.character(variables)) all_variables <- variables(x, reserved = TRUE) + if (regex) { tmp <- named_list(variables) for (i in seq_along(variables)) { @@ -529,6 +532,12 @@ check_existing_variables <- function(variables, x, regex = FALSE, stop_no_call("The following variables are missing in the draws object: ", comma(missing_variables)) } + + # handle excluding variables for subset_draws + if (exclude) { + variables <- setdiff(all_variables, variables) + } + invisible(variables) } @@ -564,12 +573,13 @@ check_reserved_variables <- function(variables) { # check validity of iteration indices # @param unique should the returned IDs be unique? 
-check_iteration_ids <- function(iteration_ids, x, unique = TRUE) { +check_iteration_ids <- function(iteration_ids, x, unique = TRUE, exclude = FALSE) { check_draws_object(x) if (is.null(iteration_ids)) { return(NULL) } unique <- as_one_logical(unique) + exclude <- as_one_logical(exclude) iteration_ids <- as.integer(iteration_ids) if (unique) { iteration_ids <- unique(iteration_ids) @@ -584,17 +594,24 @@ check_iteration_ids <- function(iteration_ids, x, unique = TRUE) { stop_no_call("Tried to subset iterations up to '", max_iteration, "' ", "but the object only has '", niterations, "' iterations.") } + + # handle exclude iterations in subset_draws + if (exclude) { + iteration_ids <- setdiff(iteration_ids(x), iteration_ids) + } + invisible(iteration_ids) } # check validity of chain indices # @param unique should the returned IDs be unique? -check_chain_ids <- function(chain_ids, x, unique = TRUE) { +check_chain_ids <- function(chain_ids, x, unique = TRUE, exclude = FALSE) { check_draws_object(x) if (is.null(chain_ids)) { return(NULL) } unique <- as_one_logical(unique) + exclude <- as_one_logical(exclude) chain_ids <- as.integer(chain_ids) if (unique) { chain_ids <- unique(chain_ids) @@ -609,17 +626,23 @@ check_chain_ids <- function(chain_ids, x, unique = TRUE) { stop_no_call("Tried to subset chains up to '", max_chain, "' ", "but the object only has '", nchains, "' chains.") } + + if (exclude) { + chain_ids <- setdiff(chain_ids(x), chain_ids) + } + invisible(chain_ids) } # check validity of draw indices # @param unique should the returned IDs be unique? 
-check_draw_ids <- function(draw_ids, x, unique = TRUE) { +check_draw_ids <- function(draw_ids, x, unique = TRUE, exclude = FALSE) { check_draws_object(x) if (is.null(draw_ids)) { return(NULL) } unique <- as_one_logical(unique) + exclude <- as_one_logical(exclude) draw_ids <- as.integer(draw_ids) if (unique) { draw_ids <- unique(draw_ids) @@ -634,5 +657,10 @@ check_draw_ids <- function(draw_ids, x, unique = TRUE) { stop_no_call("Tried to subset draws up to '", max_draw, "' ", "but the object only has '", ndraws, "' draws.") } + + if (exclude) { + draw_ids <- setdiff(draw_ids(x), draw_ids) + } + invisible(draw_ids) } diff --git a/R/subset_draws.R b/R/subset_draws.R index 8465f14f..7802c63f 100644 --- a/R/subset_draws.R +++ b/R/subset_draws.R @@ -3,19 +3,25 @@ #' Subset [`draws`] objects by variables, iterations, chains, and draws indices. #' #' @template args-methods-x -#' @param variable (character vector) The variables to select. All elements of -#' non-scalar variables can be selected at once. +#' @param variable (character vector) The variables to select. All +#' elements of non-scalar variables can be selected at once. #' @param iteration (integer vector) The iteration indices to select. #' @param chain (integer vector) The chain indices to select. -#' @param draw (integer vector) The draw indices to be select. Subsetting draw -#' indices will lead to an automatic merging of chains via [`merge_chains`]. +#' @param draw (integer vector) The draw indices to +#' select. Subsetting draw indices will lead to an automatic merging +#' of chains via [`merge_chains`]. #' @param regex (logical) Should `variable` should be treated as a -#' (vector of) regular expressions? Any variable in `x` matching at least one -#' of the regular expressions will be selected. Defaults to `FALSE`. -#' @param unique (logical) Should duplicated selection of chains, iterations, or -#' draws be allowed? 
If `TRUE` (the default) only unique chains, iterations, -#' and draws are selected regardless of how often they appear in the -#' respective selecting arguments. +#' (vector of) regular expressions? Any variable in `x` matching at +#' least one of the regular expressions will be selected. Defaults +#' to `FALSE`. +#' @param unique (logical) Should duplicated selection of chains, +#' iterations, or draws be allowed? If `TRUE` (the default) only +#' unique chains, iterations, and draws are selected regardless of +#' how often they appear in the respective selecting arguments. +#' @param exclude (logical) Should the selected subset be excluded? +#' If `FALSE` (the default) only the selected subset will be +#' returned. If `TRUE` everything but the selected subset will be +#' returned. #' #' @template args-methods-dots #' @template return-draws @@ -46,15 +52,16 @@ subset_draws <- function(x, ...) { #' @export subset_draws.draws_matrix <- function(x, variable = NULL, iteration = NULL, chain = NULL, draw = NULL, regex = FALSE, - unique = TRUE, ...) { + unique = TRUE, exclude = FALSE, ...) 
{ if (all_null(variable, iteration, chain, draw)) { return(x) } x <- repair_draws(x) - variable <- check_existing_variables(variable, x, regex = regex) - iteration <- check_iteration_ids(iteration, x, unique = unique) - chain <- check_chain_ids(chain, x, unique = unique) - draw <- check_draw_ids(draw, x, unique = unique) + variable <- check_existing_variables(variable, x, regex = regex, exclude = exclude) + iteration <- check_iteration_ids(iteration, x, unique = unique, exclude = exclude) + chain <- check_chain_ids(chain, x, unique = unique, exclude = exclude) + draw <- check_draw_ids(draw, x, unique = unique, exclude = exclude) + x <- prepare_subsetting(x, iteration, chain, draw) x <- .subset_draws(x, iteration, chain, draw, variable, reserved = TRUE) if (!is.null(chain) || !is.null(iteration)) { @@ -67,15 +74,17 @@ subset_draws.draws_matrix <- function(x, variable = NULL, iteration = NULL, #' @export subset_draws.draws_array <- function(x, variable = NULL, iteration = NULL, chain = NULL, draw = NULL, regex = FALSE, - unique = TRUE, ...) { + unique = TRUE, exclude = FALSE, ...) 
{ if (all_null(variable, iteration, chain, draw)) { return(x) } + x <- repair_draws(x) - variable <- check_existing_variables(variable, x, regex = regex) - iteration <- check_iteration_ids(iteration, x, unique = unique) - chain <- check_chain_ids(chain, x, unique = unique) - draw <- check_draw_ids(draw, x, unique = unique) + variable <- check_existing_variables(variable, x, regex = regex, exclude = exclude) + iteration <- check_iteration_ids(iteration, x, unique = unique, exclude = exclude) + chain <- check_chain_ids(chain, x, unique = unique, exclude = exclude) + draw <- check_draw_ids(draw, x, unique = unique, exclude = exclude) + x <- prepare_subsetting(x, iteration, chain, draw) if (!is.null(draw)) { iteration <- draw @@ -91,16 +100,18 @@ subset_draws.draws_array <- function(x, variable = NULL, iteration = NULL, #' @export subset_draws.draws_df <- function(x, variable = NULL, iteration = NULL, chain = NULL, draw = NULL, regex = FALSE, - unique = TRUE, ...) { + unique = TRUE, exclude = FALSE, ...) 
{ if (all_null(variable, iteration, chain, draw)) { return(x) } + x <- repair_draws(x) unique <- as_one_logical(unique) - variable <- check_existing_variables(variable, x, regex = regex) - iteration <- check_iteration_ids(iteration, x, unique = unique) - chain <- check_chain_ids(chain, x, unique = unique) - draw <- check_draw_ids(draw, x, unique = unique) + variable <- check_existing_variables(variable, x, regex = regex, exclude= exclude) + iteration <- check_iteration_ids(iteration, x, unique = unique, exclude= exclude) + chain <- check_chain_ids(chain, x, unique = unique, exclude= exclude) + draw <- check_draw_ids(draw, x, unique = unique, exclude= exclude) + x <- prepare_subsetting(x, iteration, chain, draw) x <- .subset_draws( x, iteration, chain, draw, variable, @@ -113,15 +124,17 @@ subset_draws.draws_df <- function(x, variable = NULL, iteration = NULL, #' @export subset_draws.draws_list <- function(x, variable = NULL, iteration = NULL, chain = NULL, draw = NULL, regex = FALSE, - unique = TRUE, ...) { + unique = TRUE, exclude = FALSE, ...) 
{ if (all_null(variable, iteration, chain, draw)) { return(x) } + x <- repair_draws(x) - variable <- check_existing_variables(variable, x, regex = regex) - iteration <- check_iteration_ids(iteration, x, unique = unique) - chain <- check_chain_ids(chain, x, unique = unique) - draw <- check_draw_ids(draw, x, unique = unique) + variable <- check_existing_variables(variable, x, regex = regex, exclude = exclude) + iteration <- check_iteration_ids(iteration, x, unique = unique, exclude = exclude) + chain <- check_chain_ids(chain, x, unique = unique, exclude = exclude) + draw <- check_draw_ids(draw, x, unique = unique, exclude = exclude) + x <- prepare_subsetting(x, iteration, chain, draw) if (!is.null(draw)) { iteration <- draw @@ -137,15 +150,17 @@ subset_draws.draws_list <- function(x, variable = NULL, iteration = NULL, #' @export subset_draws.draws_rvars <- function(x, variable = NULL, iteration = NULL, chain = NULL, draw = NULL, regex = FALSE, - unique = TRUE, ...) { + unique = TRUE, exclude = FALSE, ...) { if (all_null(variable, iteration, chain, draw)) { return(x) } + x <- repair_draws(x) - variable <- check_existing_variables(variable, x, regex = regex) - iteration <- check_iteration_ids(iteration, x, unique = unique) - chain <- check_chain_ids(chain, x, unique = unique) - draw <- check_draw_ids(draw, x, unique = unique) + variable <- check_existing_variables(variable, x, regex = regex, exclude = exclude) + iteration <- check_iteration_ids(iteration, x, unique = unique, exclude= exclude) + chain <- check_chain_ids(chain, x, unique = unique, exclude= exclude) + draw <- check_draw_ids(draw, x, unique = unique, exclude= exclude) + x <- prepare_subsetting(x, iteration, chain, draw) if (!is.null(draw)) { iteration <- draw diff --git a/man/subset_draws.Rd b/man/subset_draws.Rd index 26d423d4..c5698e08 100644 --- a/man/subset_draws.Rd +++ b/man/subset_draws.Rd @@ -21,6 +21,7 @@ subset_draws(x, ...) draw = NULL, regex = FALSE, unique = TRUE, + exclude = FALSE, ... 
) @@ -32,6 +33,7 @@ subset_draws(x, ...) draw = NULL, regex = FALSE, unique = TRUE, + exclude = FALSE, ... ) @@ -43,6 +45,7 @@ subset_draws(x, ...) draw = NULL, regex = FALSE, unique = TRUE, + exclude = FALSE, ... ) @@ -54,6 +57,7 @@ subset_draws(x, ...) draw = NULL, regex = FALSE, unique = TRUE, + exclude = FALSE, ... ) @@ -65,6 +69,7 @@ subset_draws(x, ...) draw = NULL, regex = FALSE, unique = TRUE, + exclude = FALSE, ... ) @@ -78,24 +83,30 @@ is defined.} \item{...}{Arguments passed to individual methods (if applicable).} -\item{variable}{(character vector) The variables to select. All elements of -non-scalar variables can be selected at once.} +\item{variable}{(character vector) The variables to select. All +elements of non-scalar variables can be selected at once.} \item{iteration}{(integer vector) The iteration indices to select.} \item{chain}{(integer vector) The chain indices to select.} -\item{draw}{(integer vector) The draw indices to be select. Subsetting draw -indices will lead to an automatic merging of chains via \code{\link{merge_chains}}.} +\item{draw}{(integer vector) The draw indices to +select. Subsetting draw indices will lead to an automatic merging +of chains via \code{\link{merge_chains}}.} \item{regex}{(logical) Should \code{variable} should be treated as a -(vector of) regular expressions? Any variable in \code{x} matching at least one -of the regular expressions will be selected. Defaults to \code{FALSE}.} - -\item{unique}{(logical) Should duplicated selection of chains, iterations, or -draws be allowed? If \code{TRUE} (the default) only unique chains, iterations, -and draws are selected regardless of how often they appear in the -respective selecting arguments.} +(vector of) regular expressions? Any variable in \code{x} matching at +least one of the regular expressions will be selected. Defaults +to \code{FALSE}.} + +\item{unique}{(logical) Should duplicated selection of chains, +iterations, or draws be allowed? 
If \code{TRUE} (the default) only +unique chains, iterations, and draws are selected regardless of +how often they appear in the respective selecting arguments.} + +\item{exclude}{(logical) Should the selected subset be excluded? +If \code{FALSE} (the default) only the selected subset will be returned. +If \code{TRUE} everything but the selected subset will be returned.} } \value{ A \code{draws} object of the same class as \code{x}. diff --git a/tests/testthat/test-subset_draws.R b/tests/testthat/test-subset_draws.R index a5170ac4..bae8bc2f 100644 --- a/tests/testthat/test-subset_draws.R +++ b/tests/testthat/test-subset_draws.R @@ -1,4 +1,5 @@ test_that("subset_draws works correctly for draws_matrix objects", { + x <- as_draws_matrix(example_draws()) x_sub <- subset_draws(x, variable = c("mu", "tau"), iteration = 5:10) x_sub2 <- x[c(5:10, 105:110, 205:210, 305:310), c("mu", "tau")] @@ -12,6 +13,18 @@ test_that("subset_draws works correctly for draws_matrix objects", { x <- weight_draws(x, rep(1, ndraws(x))) x_sub <- subset_draws(x, variable = "mu") expect_equal(variables(x_sub, reserved = TRUE), c("mu", ".log_weight")) + + x_sub <- subset_draws(x, variable = "mu", chain = c(1, 2, 3), exclude = TRUE) + expect_equal(setdiff(variables(x, reserved = TRUE), "mu"), variables(x_sub, reserved = TRUE)) + expect_equal(nchains(x_sub), 1) + + x_sub <- subset_draws(x, draw = c(1, 2, 3), exclude = TRUE) + expect_equal(ndraws(x) - 3, ndraws(x_sub)) + + x_sub <- subset_draws(x, iteration = c(1, 2, 3), exclude = TRUE) + expect_equal(ndraws(x) - 3 * nchains(x), ndraws(x_sub)) + + }) test_that("subset_draws works correctly for draws_array objects", { @@ -35,6 +48,16 @@ test_that("subset_draws works correctly for draws_array objects", { x <- weight_draws(x, rep(1, ndraws(x))) x_sub <- subset_draws(x, variable = "mu") expect_equal(variables(x_sub, reserved = TRUE), c("mu", ".log_weight")) + + x_sub <- subset_draws(x, variable = "mu", chain = c(1, 2, 3), exclude = TRUE) + 
expect_equal(setdiff(variables(x, reserved = TRUE), "mu"), variables(x_sub, reserved = TRUE)) + expect_equal(nchains(x_sub), 1) + + x_sub <- subset_draws(x, draw = c(1, 2, 3), exclude = TRUE) + expect_equal(ndraws(x) - 3, ndraws(x_sub)) + + x_sub <- subset_draws(x, iteration = c(1, 2, 3), exclude = TRUE) + expect_equal(ndraws(x) - 3 * nchains(x), ndraws(x_sub)) }) test_that("subset_draws works correctly for draws_df objects", { @@ -51,6 +74,16 @@ test_that("subset_draws works correctly for draws_df objects", { x <- weight_draws(x, rep(1, ndraws(x))) x_sub <- subset_draws(x, variable = "mu") expect_equal(names(x_sub), c("mu", ".log_weight", ".chain", ".iteration", ".draw")) + + x_sub <- subset_draws(x, variable = "mu", chain = c(1, 2, 3), exclude = TRUE) + expect_equal(setdiff(variables(x, reserved = TRUE), "mu"), variables(x_sub, reserved = TRUE)) + expect_equal(nchains(x_sub), 1) + + x_sub <- subset_draws(x, draw = c(1, 2, 3), exclude = TRUE) + expect_equal(ndraws(x) - 3, ndraws(x_sub)) + + x_sub <- subset_draws(x, iteration = c(1, 2, 3), exclude = TRUE) + expect_equal(ndraws(x) - 3 * nchains(x), ndraws(x_sub)) }) test_that("subset_draws works correctly for draws_list objects", { @@ -73,6 +106,16 @@ test_that("subset_draws works correctly for draws_list objects", { x <- weight_draws(x, rep(1, ndraws(x))) x_sub <- subset_draws(x, variable = "mu") expect_equal(variables(x_sub, reserved = TRUE), c("mu", ".log_weight")) + + x_sub <- subset_draws(x, variable = "mu", chain = c(1, 2, 3), exclude = TRUE) + expect_equal(setdiff(variables(x, reserved = TRUE), "mu"), variables(x_sub, reserved = TRUE)) + expect_equal(nchains(x_sub), 1) + + x_sub <- subset_draws(x, draw = c(1, 2, 3), exclude = TRUE) + expect_equal(ndraws(x) - 3, ndraws(x_sub)) + + x_sub <- subset_draws(x, iteration = c(1, 2, 3), exclude = TRUE) + expect_equal(ndraws(x) - 3 * nchains(x), ndraws(x_sub)) }) test_that("subset_draws works correctly for draws_rvars objects", { @@ -95,6 +138,16 @@ 
test_that("subset_draws works correctly for draws_rvars objects", { x <- weight_draws(x, rep(1, ndraws(x))) x_sub <- subset_draws(x, variable = "mu") expect_equal(variables(x_sub, reserved = TRUE), c("mu", ".log_weight")) + + x_sub <- subset_draws(x, variable = "mu", chain = c(1, 2, 3), exclude = TRUE) + expect_equal(setdiff(variables(x, reserved = TRUE), "mu"), variables(x_sub, reserved = TRUE)) + expect_equal(nchains(x_sub), 1) + + x_sub <- subset_draws(x, draw = c(1, 2, 3), exclude = TRUE) + expect_equal(ndraws(x) - 3, ndraws(x_sub)) + + x_sub <- subset_draws(x, iteration = c(1, 2, 3), exclude = TRUE) + expect_equal(ndraws(x) - 3 * nchains(x), ndraws(x_sub)) }) test_that("subset_draws works correctly for rvar objects", { @@ -111,6 +164,17 @@ test_that("subset_draws works correctly for rvar objects", { "Merging chains in order to subset via 'draw'" ) expect_equal(niterations(x_sub), 3) + + x_sub <- subset_draws(x, iteration = c(1, 2), chain = c(1, 2, 3), exclude = TRUE) + expect_equal(niterations(x_sub), niterations(x) - 2) + expect_equal(nchains(x_sub), 1) + + x_sub <- subset_draws(x, draw = c(1, 2, 3), exclude = TRUE) + expect_equal(ndraws(x) - 3, ndraws(x_sub)) + + x_sub <- subset_draws(x, iteration = c(1, 2, 3), exclude = TRUE) + expect_equal(ndraws(x) - 3 * nchains(x), ndraws(x_sub)) + }) test_that("variables can be subsetted via regular expressions", {