From 43e6b4be0548be96212b1e9d23413e39e889684d Mon Sep 17 00:00:00 2001 From: n-kall Date: Thu, 11 Jan 2024 13:07:34 +0200 Subject: [PATCH 1/9] add exclude argument to subset_draws --- R/subset_draws.R | 112 ++++++++++++++++++++++++++++++++++++++------ man/subset_draws.Rd | 34 +++++++++----- 2 files changed, 120 insertions(+), 26 deletions(-) diff --git a/R/subset_draws.R b/R/subset_draws.R index 8465f14f..650660d5 100644 --- a/R/subset_draws.R +++ b/R/subset_draws.R @@ -3,19 +3,25 @@ #' Subset [`draws`] objects by variables, iterations, chains, and draws indices. #' #' @template args-methods-x -#' @param variable (character vector) The variables to select. All elements of -#' non-scalar variables can be selected at once. +#' @param variable (character vector) The variables to select. All +#' elements of non-scalar variables can be selected at once. #' @param iteration (integer vector) The iteration indices to select. #' @param chain (integer vector) The chain indices to select. -#' @param draw (integer vector) The draw indices to be select. Subsetting draw -#' indices will lead to an automatic merging of chains via [`merge_chains`]. +#' @param draw (integer vector) The draw indices to be +#' select. Subsetting draw indices will lead to an automatic merging +#' of chains via [`merge_chains`]. #' @param regex (logical) Should `variable` should be treated as a -#' (vector of) regular expressions? Any variable in `x` matching at least one -#' of the regular expressions will be selected. Defaults to `FALSE`. -#' @param unique (logical) Should duplicated selection of chains, iterations, or -#' draws be allowed? If `TRUE` (the default) only unique chains, iterations, -#' and draws are selected regardless of how often they appear in the -#' respective selecting arguments. +#' (vector of) regular expressions? Any variable in `x` matching at +#' least one of the regular expressions will be selected. Defaults +#' to `FALSE`. +#' @param unique (logical) Should duplicated selection of chains, +#' iterations, or draws be allowed? If `TRUE` (the default) only +#' unique chains, iterations, and draws are selected regardless of +#' how often they appear in the respective selecting arguments. +#' @param exclude (logical) Should the matched selection be excluded? +#' If `FALSE` (the default) the matched subset of draws will be +#' returned. If `TRUE` the draws excluding the matched subset will +#' be returned. #' #' @template args-methods-dots #' @template return-draws @@ -46,7 +52,7 @@ subset_draws <- function(x, ...) { #' @export subset_draws.draws_matrix <- function(x, variable = NULL, iteration = NULL, chain = NULL, draw = NULL, regex = FALSE, - unique = TRUE, ...) { + unique = TRUE, exclude = FALSE, ...) { if (all_null(variable, iteration, chain, draw)) { return(x) } @@ -55,6 +61,14 @@ subset_draws.draws_matrix <- function(x, variable = NULL, iteration = NULL, iteration <- check_iteration_ids(iteration, x, unique = unique) chain <- check_chain_ids(chain, x, unique = unique) draw <- check_draw_ids(draw, x, unique = unique) + + if (exclude) { + variable <- setdiff(variables(x), variable) + iteration <- setdiff(iteration_ids(x), iteration) + chain <- setdiff(chain_ids(x), chain) + draw <- setdiff(draw_ids(x), draw) + } + x <- prepare_subsetting(x, iteration, chain, draw) x <- .subset_draws(x, iteration, chain, draw, variable, reserved = TRUE) if (!is.null(chain) || !is.null(iteration)) { @@ -67,15 +81,32 @@ subset_draws.draws_matrix <- function(x, variable = NULL, iteration = NULL, #' @export subset_draws.draws_array <- function(x, variable = NULL, iteration = NULL, chain = NULL, draw = NULL, regex = FALSE, - unique = TRUE, ...) { + unique = TRUE, exclude = FALSE, ...) { if (all_null(variable, iteration, chain, draw)) { return(x) } + x <- repair_draws(x) variable <- check_existing_variables(variable, x, regex = regex) iteration <- check_iteration_ids(iteration, x, unique = unique) chain <- check_chain_ids(chain, x, unique = unique) draw <- check_draw_ids(draw, x, unique = unique) + + if (exclude) { + if (!is.null(variable)) { + variable <- setdiff(variables(x), variable) + } + if (!is.null(iteration)) { + iteration <- setdiff(iteration_ids(x), iteration) + } + if (!is.null(chain)) { + chain <- setdiff(chain_ids(x), chain) + } + if (!is.null(draw)) { + draw <- setdiff(draw_ids(x), draw) + } + } + x <- prepare_subsetting(x, iteration, chain, draw) if (!is.null(draw)) { iteration <- draw @@ -91,16 +122,33 @@ subset_draws.draws_array <- function(x, variable = NULL, iteration = NULL, #' @export subset_draws.draws_df <- function(x, variable = NULL, iteration = NULL, chain = NULL, draw = NULL, regex = FALSE, - unique = TRUE, ...) { + unique = TRUE, exclude = FALSE, ...) { if (all_null(variable, iteration, chain, draw)) { return(x) } + x <- repair_draws(x) unique <- as_one_logical(unique) variable <- check_existing_variables(variable, x, regex = regex) iteration <- check_iteration_ids(iteration, x, unique = unique) chain <- check_chain_ids(chain, x, unique = unique) draw <- check_draw_ids(draw, x, unique = unique) + + if (exclude) { + if (!is.null(variable)) { + variable <- setdiff(variables(x), variable) + } + if (!is.null(iteration)) { + iteration <- setdiff(iteration_ids(x), iteration) + } + if (!is.null(chain)) { + chain <- setdiff(chain_ids(x), chain) + } + if (!is.null(draw)) { + draw <- setdiff(draw_ids(x), draw) + } + } + x <- prepare_subsetting(x, iteration, chain, draw) x <- .subset_draws( x, iteration, chain, draw, variable, @@ -113,15 +161,32 @@ subset_draws.draws_df <- function(x, variable = NULL, iteration = NULL, #' @export subset_draws.draws_list <- function(x, variable = NULL, iteration = NULL, chain = NULL, draw = NULL, regex = FALSE, - unique = TRUE, ...) { + unique = TRUE, exclude = FALSE, ...) { if (all_null(variable, iteration, chain, draw)) { return(x) } + x <- repair_draws(x) variable <- check_existing_variables(variable, x, regex = regex) iteration <- check_iteration_ids(iteration, x, unique = unique) chain <- check_chain_ids(chain, x, unique = unique) draw <- check_draw_ids(draw, x, unique = unique) + + if (exclude) { + if (!is.null(variable)) { + variable <- setdiff(variables(x), variable) + } + if (!is.null(iteration)) { + iteration <- setdiff(iteration_ids(x), iteration) + } + if (!is.null(chain)) { + chain <- setdiff(chain_ids(x), chain) + } + if (!is.null(draw)) { + draw <- setdiff(draw_ids(x), draw) + } + } + x <- prepare_subsetting(x, iteration, chain, draw) if (!is.null(draw)) { iteration <- draw @@ -137,15 +202,32 @@ subset_draws.draws_list <- function(x, variable = NULL, iteration = NULL, #' @export subset_draws.draws_rvars <- function(x, variable = NULL, iteration = NULL, chain = NULL, draw = NULL, regex = FALSE, - unique = TRUE, ...) { + unique = TRUE, exclude = FALSE, ...) { if (all_null(variable, iteration, chain, draw)) { return(x) } + x <- repair_draws(x) variable <- check_existing_variables(variable, x, regex = regex) iteration <- check_iteration_ids(iteration, x, unique = unique) chain <- check_chain_ids(chain, x, unique = unique) draw <- check_draw_ids(draw, x, unique = unique) + + if (exclude) { + if (!is.null(variable)) { + variable <- setdiff(variables(x), variable) + } + if (!is.null(iteration)) { + iteration <- setdiff(iteration_ids(x), iteration) + } + if (!is.null(chain)) { + chain <- setdiff(chain_ids(x), chain) + } + if (!is.null(draw)) { + draw <- setdiff(draw_ids(x), draw) + } + } + x <- prepare_subsetting(x, iteration, chain, draw) if (!is.null(draw)) { iteration <- draw diff --git a/man/subset_draws.Rd b/man/subset_draws.Rd index 26d423d4..d50e3445 100644 --- a/man/subset_draws.Rd +++ b/man/subset_draws.Rd @@ -21,6 +21,7 @@ subset_draws(x, ...) draw = NULL, regex = FALSE, unique = TRUE, + exclude = FALSE, ... ) @@ -32,6 +33,7 @@ subset_draws(x, ...) draw = NULL, regex = FALSE, unique = TRUE, + exclude = FALSE, ... ) @@ -43,6 +45,7 @@ subset_draws(x, ...) draw = NULL, regex = FALSE, unique = TRUE, + exclude = FALSE, ... ) @@ -54,6 +57,7 @@ subset_draws(x, ...) draw = NULL, regex = FALSE, unique = TRUE, + exclude = FALSE, ... ) @@ -65,6 +69,7 @@ subset_draws(x, ...) draw = NULL, regex = FALSE, unique = TRUE, + exclude = FALSE, ... ) @@ -78,24 +83,31 @@ is defined.} \item{...}{Arguments passed to individual methods (if applicable).} -\item{variable}{(character vector) The variables to select. All elements of -non-scalar variables can be selected at once.} +\item{variable}{(character vector) The variables to select. All +elements of non-scalar variables can be selected at once.} \item{iteration}{(integer vector) The iteration indices to select.} \item{chain}{(integer vector) The chain indices to select.} -\item{draw}{(integer vector) The draw indices to be select. Subsetting draw -indices will lead to an automatic merging of chains via \code{\link{merge_chains}}.} +\item{draw}{(integer vector) The draw indices to be +select. Subsetting draw indices will lead to an automatic merging +of chains via \code{\link{merge_chains}}.} \item{regex}{(logical) Should \code{variable} should be treated as a -(vector of) regular expressions? Any variable in \code{x} matching at least one -of the regular expressions will be selected. Defaults to \code{FALSE}.} - -\item{unique}{(logical) Should duplicated selection of chains, iterations, or -draws be allowed? If \code{TRUE} (the default) only unique chains, iterations, -and draws are selected regardless of how often they appear in the -respective selecting arguments.} +(vector of) regular expressions? Any variable in \code{x} matching at +least one of the regular expressions will be selected. Defaults +to \code{FALSE}.} + +\item{unique}{(logical) Should duplicated selection of chains, +iterations, or draws be allowed? If \code{TRUE} (the default) only +unique chains, iterations, and draws are selected regardless of +how often they appear in the respective selecting arguments.} + +\item{exclude}{(logical) Should the matched selection be excluded? +If \code{FALSE} (the default) the matched subset of draws will be +returned. If \code{TRUE} the draws excluding the matched subset will +be returned.} } \value{ A \code{draws} object of the same class as \code{x}. From e00566cc69d079b20df273579729f1c5f4ed8fc1 Mon Sep 17 00:00:00 2001 From: n-kall Date: Thu, 11 Jan 2024 14:27:31 +0200 Subject: [PATCH 2/9] fix handling of reserved variables when subsetting with exclude --- R/subset_draws.R | 40 ++++++++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/R/subset_draws.R b/R/subset_draws.R index 650660d5..1b6b82c4 100644 --- a/R/subset_draws.R +++ b/R/subset_draws.R @@ -63,12 +63,20 @@ subset_draws.draws_matrix <- function(x, variable = NULL, iteration = NULL, draw <- check_draw_ids(draw, x, unique = unique) if (exclude) { - variable <- setdiff(variables(x), variable) - iteration <- setdiff(iteration_ids(x), iteration) - chain <- setdiff(chain_ids(x), chain) - draw <- setdiff(draw_ids(x), draw) + if (!is.null(variable)) { + variable <- setdiff(variables(x, reserved = TRUE), variable) + } + if (!is.null(iteration)) { + iteration <- setdiff(iteration_ids(x), iteration) + } + if (!is.null(chain)) { + chain <- setdiff(chain_ids(x), chain) + } + if (!is.null(draw)) { + draw <- setdiff(draw_ids(x), draw) + } } - + x <- prepare_subsetting(x, iteration, chain, draw) x <- .subset_draws(x, iteration, chain, draw, variable, reserved = TRUE) if (!is.null(chain) || !is.null(iteration)) { @@ -85,7 +93,7 @@ subset_draws.draws_array <- function(x, variable = NULL, iteration = NULL, if (all_null(variable, iteration, chain, draw)) { return(x) } - + x <- repair_draws(x) variable <- check_existing_variables(variable, x, regex = regex) iteration <- check_iteration_ids(iteration, x, unique = unique) @@ -94,7 +102,7 @@ subset_draws.draws_array <- function(x, variable = NULL, iteration = NULL, if (exclude) { if (!is.null(variable)) { - variable <- setdiff(variables(x), variable) + variable <- setdiff(variables(x, reserved = TRUE), variable) } if (!is.null(iteration)) { iteration <- setdiff(iteration_ids(x), iteration) @@ -106,7 +114,7 @@ subset_draws.draws_array <- function(x, variable = NULL, iteration = NULL, draw <- setdiff(draw_ids(x), draw) } } - + x <- prepare_subsetting(x, iteration, chain, draw) if (!is.null(draw)) { iteration <- draw @@ -126,7 +134,7 @@ subset_draws.draws_df <- function(x, variable = NULL, iteration = NULL, if (all_null(variable, iteration, chain, draw)) { return(x) } - + x <- repair_draws(x) unique <- as_one_logical(unique) variable <- check_existing_variables(variable, x, regex = regex) @@ -136,7 +144,7 @@ subset_draws.draws_df <- function(x, variable = NULL, iteration = NULL, if (exclude) { if (!is.null(variable)) { - variable <- setdiff(variables(x), variable) + variable <- setdiff(variables(x, reserved = TRUE), variable) } if (!is.null(iteration)) { iteration <- setdiff(iteration_ids(x), iteration) @@ -148,7 +156,7 @@ subset_draws.draws_df <- function(x, variable = NULL, iteration = NULL, draw <- setdiff(draw_ids(x), draw) } } - + x <- prepare_subsetting(x, iteration, chain, draw) x <- .subset_draws( x, iteration, chain, draw, variable, @@ -174,7 +182,7 @@ subset_draws.draws_list <- function(x, variable = NULL, iteration = NULL, if (exclude) { if (!is.null(variable)) { - variable <- setdiff(variables(x), variable) + variable <- setdiff(variables(x, reserved = TRUE), variable) } if (!is.null(iteration)) { iteration <- setdiff(iteration_ids(x), iteration) @@ -186,7 +194,7 @@ subset_draws.draws_list <- function(x, variable = NULL, iteration = NULL, draw <- setdiff(draw_ids(x), draw) } } - + x <- prepare_subsetting(x, iteration, chain, draw) if (!is.null(draw)) { iteration <- draw @@ -206,7 +214,7 @@ subset_draws.draws_rvars <- function(x, variable = NULL, iteration = NULL, if (all_null(variable, iteration, chain, draw)) { return(x) } - + x <- repair_draws(x) variable <- check_existing_variables(variable, x, regex = regex) iteration <- check_iteration_ids(iteration, x, unique = unique) @@ -215,7 +223,7 @@ subset_draws.draws_rvars <- function(x, variable = NULL, iteration = NULL, if (exclude) { if (!is.null(variable)) { - variable <- setdiff(variables(x), variable) + variable <- setdiff(variables(x, reserved = TRUE), variable) } if (!is.null(iteration)) { iteration <- setdiff(iteration_ids(x), iteration) @@ -227,7 +235,7 @@ subset_draws.draws_rvars <- function(x, variable = NULL, iteration = NULL, draw <- setdiff(draw_ids(x), draw) } } - + x <- prepare_subsetting(x, iteration, chain, draw) if (!is.null(draw)) { iteration <- draw From 95c1f01f6ca7a2e230ce267de89642f6083ea7b3 Mon Sep 17 00:00:00 2001 From: n-kall Date: Thu, 11 Jan 2024 14:28:12 +0200 Subject: [PATCH 3/9] add tests for subset_draws with exclude = TRUE --- tests/testthat/test-subset_draws.R | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/testthat/test-subset_draws.R b/tests/testthat/test-subset_draws.R index a5170ac4..0b7c4862 100644 --- a/tests/testthat/test-subset_draws.R +++ b/tests/testthat/test-subset_draws.R @@ -1,4 +1,5 @@ test_that("subset_draws works correctly for draws_matrix objects", { + x <- as_draws_matrix(example_draws()) x_sub <- subset_draws(x, variable = c("mu", "tau"), iteration = 5:10) x_sub2 <- x[c(5:10, 105:110, 205:210, 305:310), c("mu", "tau")] @@ -12,6 +13,11 @@ test_that("subset_draws works correctly for draws_matrix objects", { x <- weight_draws(x, rep(1, ndraws(x))) x_sub <- subset_draws(x, variable = "mu") expect_equal(variables(x_sub, reserved = TRUE), c("mu", ".log_weight")) + + x_sub <- subset_draws(x, variable = "mu", chain = c(1, 2, 3), exclude = TRUE) + expect_equal(setdiff(variables(x, reserved = TRUE), "mu"), variables(x_sub, reserved = TRUE)) + expect_equal(nchains(x_sub), 1) + }) test_that("subset_draws works correctly for draws_array objects", { @@ -35,6 +41,10 @@ test_that("subset_draws works correctly for draws_array objects", { x <- weight_draws(x, rep(1, ndraws(x))) x_sub <- subset_draws(x, variable = "mu") expect_equal(variables(x_sub, reserved = TRUE), c("mu", ".log_weight")) + + x_sub <- subset_draws(x, variable = "mu", chain = c(1, 2, 3), exclude = TRUE) + expect_equal(setdiff(variables(x, reserved = TRUE), "mu"), variables(x_sub, reserved = TRUE)) + expect_equal(nchains(x_sub), 1) }) test_that("subset_draws works correctly for draws_df objects", { @@ -51,6 +61,10 @@ test_that("subset_draws works correctly for draws_df objects", { x <- weight_draws(x, rep(1, ndraws(x))) x_sub <- subset_draws(x, variable = "mu") expect_equal(names(x_sub), c("mu", ".log_weight", ".chain", ".iteration", ".draw")) + + x_sub <- subset_draws(x, variable = "mu", chain = c(1, 2, 3), exclude = TRUE) + expect_equal(setdiff(variables(x, reserved = TRUE), "mu"), variables(x_sub, reserved = TRUE)) + expect_equal(nchains(x_sub), 1) }) test_that("subset_draws works correctly for draws_list objects", { @@ -73,6 +87,10 @@ test_that("subset_draws works correctly for draws_list objects", { x <- weight_draws(x, rep(1, ndraws(x))) x_sub <- subset_draws(x, variable = "mu") expect_equal(variables(x_sub, reserved = TRUE), c("mu", ".log_weight")) + + x_sub <- subset_draws(x, variable = "mu", chain = c(1, 2, 3), exclude = TRUE) + expect_equal(setdiff(variables(x, reserved = TRUE), "mu"), variables(x_sub, reserved = TRUE)) + expect_equal(nchains(x_sub), 1) }) test_that("subset_draws works correctly for draws_rvars objects", { @@ -95,6 +113,10 @@ test_that("subset_draws works correctly for draws_rvars objects", { x <- weight_draws(x, rep(1, ndraws(x))) x_sub <- subset_draws(x, variable = "mu") expect_equal(variables(x_sub, reserved = TRUE), c("mu", ".log_weight")) + + x_sub <- subset_draws(x, variable = "mu", chain = c(1, 2, 3), exclude = TRUE) + expect_equal(setdiff(variables(x, reserved = TRUE), "mu"), variables(x_sub, reserved = TRUE)) + expect_equal(nchains(x_sub), 1) }) test_that("subset_draws works correctly for rvar objects", { @@ -111,6 +133,11 @@ test_that("subset_draws works correctly for rvar objects", { "Merging chains in order to subset via 'draw'" ) expect_equal(niterations(x_sub), 3) + + x_sub <- subset_draws(x, iteration = c(1, 2), chain = c(1, 2, 3), exclude = TRUE) + expect_equal(niterations(x_sub), niterations(x) - 2) + expect_equal(nchains(x_sub), 1) + }) test_that("variables can be subsetted via regular expressions", { From e1624cbfc9f77ee6113a6d52632a603370423116 Mon Sep 17 00:00:00 2001 From: n-kall Date: Thu, 11 Jan 2024 15:38:11 +0200 Subject: [PATCH 4/9] improve subset_draws exclude documentation --- R/subset_draws.R | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/R/subset_draws.R b/R/subset_draws.R index 1b6b82c4..87486650 100644 --- a/R/subset_draws.R +++ b/R/subset_draws.R @@ -18,10 +18,10 @@ #' iterations, or draws be allowed? If `TRUE` (the default) only #' unique chains, iterations, and draws are selected regardless of #' how often they appear in the respective selecting arguments. -#' @param exclude (logical) Should the matched selection be excluded? -#' If `FALSE` (the default) the matched subset of draws will be -#' returned. If `TRUE` the draws excluding the matched subset will -#' be returned. +#' @param exclude (logical) Should the selected subset be excluded? +#' If `FALSE` (the default) the selection of draws will be returned. +#' If `TRUE` all draws excluding the selected subset will be +#' returned. #' #' @template args-methods-dots #' @template return-draws From f46f75b277702cb6c121cb0fda64d2c0b1e31ee8 Mon Sep 17 00:00:00 2001 From: n-kall Date: Thu, 11 Jan 2024 15:38:55 +0200 Subject: [PATCH 5/9] documentation update --- man/subset_draws.Rd | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/man/subset_draws.Rd b/man/subset_draws.Rd index d50e3445..e8d1e82b 100644 --- a/man/subset_draws.Rd +++ b/man/subset_draws.Rd @@ -104,10 +104,10 @@ iterations, or draws be allowed? If \code{TRUE} (the default) only unique chains, iterations, and draws are selected regardless of how often they appear in the respective selecting arguments.} -\item{exclude}{(logical) Should the matched selection be excluded? -If \code{FALSE} (the default) the matched subset of draws will be -returned. If \code{TRUE} the draws excluding the matched subset will -be returned.} +\item{exclude}{(logical) Should the selected subset be excluded? +If \code{FALSE} (the default) the selection of draws will be returned. +If \code{TRUE} all draws excluding the selected subset will be +returned.} } \value{ A \code{draws} object of the same class as \code{x}. From eab298de2f417a6560c049adc65ddbe1361348ca Mon Sep 17 00:00:00 2001 From: n-kall Date: Fri, 12 Jan 2024 10:51:53 +0200 Subject: [PATCH 6/9] move excluding to check_* functions --- R/draws-index.R | 36 +++++++++++++-- R/subset_draws.R | 115 +++++++++-------------------------------------- 2 files changed, 52 insertions(+), 99 deletions(-) diff --git a/R/draws-index.R b/R/draws-index.R index 8cdfa990..748ba983 100644 --- a/R/draws-index.R +++ b/R/draws-index.R @@ -478,15 +478,18 @@ ndraws.rvar <- function(x) { # @param regex should 'variables' be treated as regular expressions? # @param scalar_only should only scalar variables be matched? check_existing_variables <- function(variables, x, regex = FALSE, - scalar_only = FALSE) { + scalar_only = FALSE, exclude = FALSE) { check_draws_object(x) if (is.null(variables)) { return(NULL) } + regex <- as_one_logical(regex) scalar_only <- as_one_logical(scalar_only) + exclude <- as_one_logical(exclude) variables <- unique(as.character(variables)) all_variables <- variables(x, reserved = TRUE) + if (regex) { tmp <- named_list(variables) for (i in seq_along(variables)) { @@ -529,6 +532,12 @@ check_existing_variables <- function(variables, x, regex = FALSE, stop_no_call("The following variables are missing in the draws object: ", comma(missing_variables)) } + + # handle excluding variables for subset_draws + if (exclude) { + variables <- setdiff(all_variables, variables) + } + invisible(variables) } @@ -564,12 +573,13 @@ check_reserved_variables <- function(variables) { # check validity of iteration indices # @param unique should the returned IDs be unique? -check_iteration_ids <- function(iteration_ids, x, unique = TRUE) { +check_iteration_ids <- function(iteration_ids, x, unique = TRUE, exclude = FALSE) { check_draws_object(x) if (is.null(iteration_ids)) { return(NULL) } unique <- as_one_logical(unique) + exclude <- as_one_logical(exclude) iteration_ids <- as.integer(iteration_ids) if (unique) { iteration_ids <- unique(iteration_ids) @@ -584,17 +594,24 @@ check_iteration_ids <- function(iteration_ids, x, unique = TRUE) { stop_no_call("Tried to subset iterations up to '", max_iteration, "' ", "but the object only has '", niterations, "' iterations.") } + + # handle exclude iterations in subset_draws + if (exclude) { + iteration_ids <- setdiff(iteration_ids(x), iteration_ids) + } + invisible(iteration_ids) } # check validity of chain indices # @param unique should the returned IDs be unique? -check_chain_ids <- function(chain_ids, x, unique = TRUE) { +check_chain_ids <- function(chain_ids, x, unique = TRUE, exclude = FALSE) { check_draws_object(x) if (is.null(chain_ids)) { return(NULL) } unique <- as_one_logical(unique) + exclude <- as_one_logical(exclude) chain_ids <- as.integer(chain_ids) if (unique) { chain_ids <- unique(chain_ids) @@ -609,17 +626,23 @@ check_chain_ids <- function(chain_ids, x, unique = TRUE) { stop_no_call("Tried to subset chains up to '", max_chain, "' ", "but the object only has '", nchains, "' chains.") } + + if (exclude) { + chain_ids <- setdiff(chain_ids(x), chain_ids) + } + invisible(chain_ids) } # check validity of draw indices # @param unique should the returned IDs be unique? -check_draw_ids <- function(draw_ids, x, unique = TRUE) { +check_draw_ids <- function(draw_ids, x, unique = TRUE, exclude = FALSE) { check_draws_object(x) if (is.null(draw_ids)) { return(NULL) } unique <- as_one_logical(unique) + exclude <- as_one_logical(exclude) draw_ids <- as.integer(draw_ids) if (unique) { draw_ids <- unique(draw_ids) @@ -634,5 +657,10 @@ check_draw_ids <- function(draw_ids, x, unique = TRUE) { stop_no_call("Tried to subset draws up to '", max_draw, "' ", "but the object only has '", ndraws, "' draws.") } + + if (exclude) { + draw_ids <- setdif(draw_ids(x), draw_ids) + } + invisible(draw_ids) } diff --git a/R/subset_draws.R b/R/subset_draws.R index 87486650..503b673b 100644 --- a/R/subset_draws.R +++ b/R/subset_draws.R @@ -57,25 +57,10 @@ subset_draws.draws_matrix <- function(x, variable = NULL, iteration = NULL, return(x) } x <- repair_draws(x) - variable <- check_existing_variables(variable, x, regex = regex) - iteration <- check_iteration_ids(iteration, x, unique = unique) - chain <- check_chain_ids(chain, x, unique = unique) - draw <- check_draw_ids(draw, x, unique = unique) - - if (exclude) { - if (!is.null(variable)) { - variable <- setdiff(variables(x, reserved = TRUE), variable) - } - if (!is.null(iteration)) { - iteration <- setdiff(iteration_ids(x), iteration) - } - if (!is.null(chain)) { - chain <- setdiff(chain_ids(x), chain) - } - if (!is.null(draw)) { - draw <- setdiff(draw_ids(x), draw) - } - } + variable <- check_existing_variables(variable, x, regex = regex, exclude = exclude) + iteration <- check_iteration_ids(iteration, x, unique = unique, exclude = exclude) + chain <- check_chain_ids(chain, x, unique = unique, exclude = exclude) + draw <- check_draw_ids(draw, x, unique = unique, exclude = exclude) x <- prepare_subsetting(x, iteration, chain, draw) x <- .subset_draws(x, iteration, chain, draw, variable, reserved = TRUE) @@ -95,25 +80,10 @@ subset_draws.draws_array <- function(x, variable = NULL, iteration = NULL, } x <- repair_draws(x) - variable <- check_existing_variables(variable, x, regex = regex) - iteration <- check_iteration_ids(iteration, x, unique = unique) - chain <- check_chain_ids(chain, x, unique = unique) - draw <- check_draw_ids(draw, x, unique = unique) - - if (exclude) { - if (!is.null(variable)) { - variable <- setdiff(variables(x, reserved = TRUE), variable) - } - if (!is.null(iteration)) { - iteration <- setdiff(iteration_ids(x), iteration) - } - if (!is.null(chain)) { - chain <- setdiff(chain_ids(x), chain) - } - if (!is.null(draw)) { - draw <- setdiff(draw_ids(x), draw) - } - } + variable <- check_existing_variables(variable, x, regex = regex, exclude = exclude) + iteration <- check_iteration_ids(iteration, x, unique = unique, exclude = exclude) + chain <- check_chain_ids(chain, x, unique = unique, exclude = exclude) + draw <- check_draw_ids(draw, x, unique = unique, exclude = exclude) x <- prepare_subsetting(x, iteration, chain, draw) if (!is.null(draw)) { @@ -137,25 +107,10 @@ subset_draws.draws_df <- function(x, variable = NULL, iteration = NULL, x <- repair_draws(x) unique <- as_one_logical(unique) - variable <- check_existing_variables(variable, x, regex = regex) - iteration <- check_iteration_ids(iteration, x, unique = unique) - chain <- check_chain_ids(chain, x, unique = unique) - draw <- check_draw_ids(draw, x, unique = unique) - - if (exclude) { - if (!is.null(variable)) { - variable <- setdiff(variables(x, reserved = TRUE), variable) - } - if (!is.null(iteration)) { - iteration <- setdiff(iteration_ids(x), iteration) - } - if (!is.null(chain)) { - chain <- setdiff(chain_ids(x), chain) - } - if (!is.null(draw)) { - draw <- setdiff(draw_ids(x), draw) - } - } + variable <- check_existing_variables(variable, x, regex = regex, exclude= exclude) + iteration <- check_iteration_ids(iteration, x, unique = unique, exclude= exclude) + chain <- check_chain_ids(chain, x, unique = unique, exclude= exclude) + draw <- check_draw_ids(draw, x, unique = unique, exclude= exclude) x <- prepare_subsetting(x, iteration, chain, draw) x <- .subset_draws( @@ -175,25 +130,10 @@ subset_draws.draws_list <- function(x, variable = NULL, iteration = NULL, } x <- repair_draws(x) - variable <- check_existing_variables(variable, x, regex = regex) - iteration <- check_iteration_ids(iteration, x, unique = unique) - chain <- check_chain_ids(chain, x, unique = unique) - draw <- check_draw_ids(draw, x, unique = unique) - - if (exclude) { - if (!is.null(variable)) { - variable <- setdiff(variables(x, reserved = TRUE), variable) - } - if (!is.null(iteration)) { - iteration <- setdiff(iteration_ids(x), iteration) - } - if (!is.null(chain)) { - chain <- setdiff(chain_ids(x), chain) - } - if (!is.null(draw)) { - draw <- setdiff(draw_ids(x), draw) - } - } + variable <- check_existing_variables(variable, x, regex = regex, exclude = exclude) + iteration <- check_iteration_ids(iteration, x, unique = unique, exclude = exclude) + chain <- check_chain_ids(chain, x, unique = unique, exclude = exclude) + draw <- check_draw_ids(draw, x, unique = unique, exclude = exclude) x <- prepare_subsetting(x, iteration, chain, draw) if (!is.null(draw)) { @@ -216,25 +156,10 @@ subset_draws.draws_rvars <- function(x, variable = NULL, iteration = NULL, } x <- repair_draws(x) - variable <- check_existing_variables(variable, x, regex = regex) - iteration <- check_iteration_ids(iteration, x, unique = unique) - chain <- check_chain_ids(chain, x, unique = unique) - draw <- check_draw_ids(draw, x, unique = unique) - - if (exclude) { - if (!is.null(variable)) { - variable <- setdiff(variables(x, reserved = TRUE), variable) - } - if (!is.null(iteration)) { - iteration <- setdiff(iteration_ids(x), iteration) - } - if (!is.null(chain)) { - chain <- setdiff(chain_ids(x), chain) - } - if (!is.null(draw)) { - draw <- setdiff(draw_ids(x), draw) - } - } + variable <- check_existing_variables(variable, x, regex = regex, exclude = exclude) + iteration <- check_iteration_ids(iteration, x, unique = unique, exclude= exclude) + chain <- check_chain_ids(chain, x, unique = unique, exclude= exclude) + draw <- check_draw_ids(draw, x, unique = unique, exclude= exclude) x <- prepare_subsetting(x, iteration, chain, draw) if (!is.null(draw)) { From 2bcd48f13d8daf372f01e71ef835a9e90b4f1690 Mon Sep 17 00:00:00 2001 From: n-kall Date: Fri, 12 Jan 2024 10:54:42 +0200 Subject: [PATCH 7/9] modify subset_draws exclusion doc --- R/subset_draws.R | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/R/subset_draws.R b/R/subset_draws.R index 503b673b..f57e7d8a 100644 --- a/R/subset_draws.R +++ b/R/subset_draws.R @@ -19,9 +19,8 @@ #' unique chains, iterations, and draws are selected regardless of #' how often they appear in the respective selecting arguments. #' @param exclude (logical) Should the selected subset be excluded? -#' If `FALSE` (the default) the selection of draws will be returned. -#' If `TRUE` all draws excluding the selected subset will be -#' returned. +#' If `FALSE` (the default) the selection will be returned. If +#' `TRUE` all but the selected subset will be returned. #' #' @template args-methods-dots #' @template return-draws From 3905aa19ca21500e0153efc15d02fb4036b41b00 Mon Sep 17 00:00:00 2001 From: n-kall Date: Fri, 12 Jan 2024 12:13:57 +0200 Subject: [PATCH 8/9] add more tests for subset with exclude --- R/draws-index.R | 2 +- tests/testthat/test-subset_draws.R | 37 ++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/R/draws-index.R b/R/draws-index.R index 748ba983..5f1109eb 100644 --- a/R/draws-index.R +++ b/R/draws-index.R @@ -659,7 +659,7 @@ check_draw_ids <- function(draw_ids, x, unique = TRUE, exclude = FALSE) { } if (exclude) { - draw_ids <- setdif(draw_ids(x), draw_ids) + draw_ids <- setdiff(draw_ids(x), draw_ids) } invisible(draw_ids) diff --git a/tests/testthat/test-subset_draws.R b/tests/testthat/test-subset_draws.R index 0b7c4862..bae8bc2f 100644 --- a/tests/testthat/test-subset_draws.R +++ b/tests/testthat/test-subset_draws.R @@ -18,6 +18,13 @@ test_that("subset_draws works correctly for draws_matrix objects", { expect_equal(setdiff(variables(x, reserved = TRUE), "mu"), variables(x_sub, reserved = TRUE)) expect_equal(nchains(x_sub), 1) + x_sub <- subset_draws(x, draw = c(1, 2, 3), exclude = TRUE) + expect_equal(ndraws(x) - 3, ndraws(x_sub)) + + x_sub <- subset_draws(x, iteration = c(1, 2, 3), exclude = TRUE) + expect_equal(ndraws(x) - 3 * nchains(x), ndraws(x_sub)) + + }) test_that("subset_draws works correctly for draws_array objects", { @@ -45,6 +52,12 @@ test_that("subset_draws works correctly for draws_array objects", { x_sub <- subset_draws(x, variable = "mu", chain = c(1, 2, 3), exclude = TRUE) expect_equal(setdiff(variables(x, reserved = TRUE), "mu"), variables(x_sub, reserved = TRUE)) expect_equal(nchains(x_sub), 1) + + x_sub <- subset_draws(x, draw = c(1, 2, 3), exclude = TRUE) + expect_equal(ndraws(x) - 3, ndraws(x_sub)) + + x_sub <- subset_draws(x, iteration = c(1, 2, 3), exclude = TRUE) + expect_equal(ndraws(x) - 3 * nchains(x), ndraws(x_sub)) }) test_that("subset_draws works correctly for draws_df objects", { @@ -65,6 +78,12 @@ test_that("subset_draws works correctly for draws_df objects", { x_sub <- subset_draws(x, variable = "mu", chain = c(1, 2, 3), exclude = TRUE) expect_equal(setdiff(variables(x, reserved = TRUE), "mu"), variables(x_sub, reserved = TRUE)) expect_equal(nchains(x_sub), 1) + + x_sub <- subset_draws(x, draw = c(1, 2, 3), exclude = TRUE) + expect_equal(ndraws(x) - 3, ndraws(x_sub)) + + x_sub <- subset_draws(x, iteration = c(1, 2, 3), exclude = TRUE) + expect_equal(ndraws(x) - 3 * nchains(x), ndraws(x_sub)) }) test_that("subset_draws works correctly for draws_list objects", { @@ -91,6 +110,12 @@ test_that("subset_draws works correctly for draws_list objects", { x_sub <- subset_draws(x, variable = "mu", chain = c(1, 2, 3), exclude = TRUE) expect_equal(setdiff(variables(x, reserved = TRUE), "mu"), variables(x_sub, reserved = TRUE)) expect_equal(nchains(x_sub), 1) + + x_sub <- subset_draws(x, draw = c(1, 2, 3), exclude = TRUE) + expect_equal(ndraws(x) - 3, ndraws(x_sub)) + + x_sub <- subset_draws(x, iteration = c(1, 2, 3), exclude = TRUE) + expect_equal(ndraws(x) - 3 * nchains(x), ndraws(x_sub)) }) test_that("subset_draws works correctly for draws_rvars objects", { @@ -117,6 +142,12 @@ test_that("subset_draws works correctly for draws_rvars objects", { x_sub <- subset_draws(x, variable = "mu", chain = c(1, 2, 3), exclude = TRUE) expect_equal(setdiff(variables(x, reserved = TRUE), "mu"), variables(x_sub, reserved = TRUE)) expect_equal(nchains(x_sub), 1) + + x_sub <- subset_draws(x, draw = c(1, 2, 3), exclude = TRUE) + expect_equal(ndraws(x) - 3, ndraws(x_sub)) + + x_sub <- subset_draws(x, iteration = c(1, 2, 3), exclude = TRUE) + expect_equal(ndraws(x) - 3 * nchains(x), ndraws(x_sub)) }) test_that("subset_draws works correctly for rvar objects", { @@ -138,6 +169,12 @@ test_that("subset_draws works correctly for rvar objects", { expect_equal(niterations(x_sub), niterations(x) - 2) expect_equal(nchains(x_sub), 1) + x_sub <- subset_draws(x, draw = c(1, 2, 3), exclude = TRUE) + expect_equal(ndraws(x) - 3, ndraws(x_sub)) + + x_sub <- subset_draws(x, iteration = c(1, 2, 3), exclude = TRUE) + expect_equal(ndraws(x) - 3 * nchains(x), ndraws(x_sub)) + }) test_that("variables can be subsetted via regular expressions", { From f919ba114b8598c5a141d8b2e575c11554ee0c3c Mon Sep 17 00:00:00 2001 From: n-kall Date: Fri, 12 Jan 2024 12:14:30 +0200 Subject: [PATCH 9/9] improve doc for subset with exclude --- R/subset_draws.R | 5 +++-- man/subset_draws.Rd | 5 ++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/R/subset_draws.R b/R/subset_draws.R index f57e7d8a..7802c63f 100644 --- a/R/subset_draws.R +++ b/R/subset_draws.R @@ -19,8 +19,9 @@ #' unique chains, iterations, and draws are selected regardless of #' how often they appear in the respective selecting arguments. #' @param exclude (logical) Should the selected subset be excluded? -#' If `FALSE` (the default) the selection will be returned. If -#' `TRUE` all but the selected subset will be returned. +#' If `FALSE` (the default) only the selected subset will be +#' returned. If `TRUE` everything but the selected subset will be +#' returned. #' #' @template args-methods-dots #' @template return-draws diff --git a/man/subset_draws.Rd b/man/subset_draws.Rd index e8d1e82b..c5698e08 100644 --- a/man/subset_draws.Rd +++ b/man/subset_draws.Rd @@ -105,9 +105,8 @@ unique chains, iterations, and draws are selected regardless of how often they appear in the respective selecting arguments.} \item{exclude}{(logical) Should the selected subset be excluded? -If \code{FALSE} (the default) the selection of draws will be returned. -If \code{TRUE} all draws excluding the selected subset will be -returned.} +If \code{FALSE} (the default) the selection will be returned. If +\code{TRUE} all but the selected subset will be returned.} } \value{ A \code{draws} object of the same class as \code{x}.