Skip to content

Commit

Permalink
add exclude argument to subset_draws
Browse files Browse the repository at this point in the history
  • Loading branch information
n-kall committed Jan 11, 2024
1 parent 4f4613c commit 43e6b4b
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 26 deletions.
112 changes: 97 additions & 15 deletions R/subset_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,25 @@
#' Subset [`draws`] objects by variables, iterations, chains, and draws indices.
#'
#' @template args-methods-x
#' @param variable (character vector) The variables to select. All elements of
#' non-scalar variables can be selected at once.
#' @param variable (character vector) The variables to select. All
#' elements of non-scalar variables can be selected at once.
#' @param iteration (integer vector) The iteration indices to select.
#' @param chain (integer vector) The chain indices to select.
#' @param draw (integer vector) The draw indices to be select. Subsetting draw
#' indices will lead to an automatic merging of chains via [`merge_chains`].
#' @param draw (integer vector) The draw indices to be
#' select. Subsetting draw indices will lead to an automatic merging
#' of chains via [`merge_chains`].
#' @param regex (logical) Should `variable` should be treated as a
#' (vector of) regular expressions? Any variable in `x` matching at least one
#' of the regular expressions will be selected. Defaults to `FALSE`.
#' @param unique (logical) Should duplicated selection of chains, iterations, or
#' draws be allowed? If `TRUE` (the default) only unique chains, iterations,
#' and draws are selected regardless of how often they appear in the
#' respective selecting arguments.
#' (vector of) regular expressions? Any variable in `x` matching at
#' least one of the regular expressions will be selected. Defaults
#' to `FALSE`.
#' @param unique (logical) Should duplicated selection of chains,
#' iterations, or draws be allowed? If `TRUE` (the default) only
#' unique chains, iterations, and draws are selected regardless of
#' how often they appear in the respective selecting arguments.
#' @param exclude (logical) Should the matched selection be excluded?
#' If `FALSE` (the default) the matched subset of draws will be
#' returned. If `TRUE` the draws excluding the matched subset will
#' be returned.
#'
#' @template args-methods-dots
#' @template return-draws
Expand Down Expand Up @@ -46,7 +52,7 @@ subset_draws <- function(x, ...) {
#' @export
subset_draws.draws_matrix <- function(x, variable = NULL, iteration = NULL,
chain = NULL, draw = NULL, regex = FALSE,
unique = TRUE, ...) {
unique = TRUE, exclude = FALSE, ...) {
if (all_null(variable, iteration, chain, draw)) {
return(x)
}
Expand All @@ -55,6 +61,14 @@ subset_draws.draws_matrix <- function(x, variable = NULL, iteration = NULL,
iteration <- check_iteration_ids(iteration, x, unique = unique)
chain <- check_chain_ids(chain, x, unique = unique)
draw <- check_draw_ids(draw, x, unique = unique)

if (exclude) {
variable <- setdiff(variables(x), variable)
iteration <- setdiff(iteration_ids(x), iteration)
chain <- setdiff(chain_ids(x), chain)
draw <- setdiff(draw_ids(x), draw)
}

x <- prepare_subsetting(x, iteration, chain, draw)
x <- .subset_draws(x, iteration, chain, draw, variable, reserved = TRUE)
if (!is.null(chain) || !is.null(iteration)) {
Expand All @@ -67,15 +81,32 @@ subset_draws.draws_matrix <- function(x, variable = NULL, iteration = NULL,
#' @export
subset_draws.draws_array <- function(x, variable = NULL, iteration = NULL,
chain = NULL, draw = NULL, regex = FALSE,
unique = TRUE, ...) {
unique = TRUE, exclude = FALSE, ...) {
if (all_null(variable, iteration, chain, draw)) {
return(x)
}

x <- repair_draws(x)
variable <- check_existing_variables(variable, x, regex = regex)
iteration <- check_iteration_ids(iteration, x, unique = unique)
chain <- check_chain_ids(chain, x, unique = unique)
draw <- check_draw_ids(draw, x, unique = unique)

if (exclude) {
if (!is.null(variable)) {
variable <- setdiff(variables(x), variable)
}
if (!is.null(iteration)) {
iteration <- setdiff(iteration_ids(x), iteration)
}
if (!is.null(chain)) {
chain <- setdiff(chain_ids(x), chain)
}
if (!is.null(draw)) {
draw <- setdiff(draw_ids(x), draw)
}
}

x <- prepare_subsetting(x, iteration, chain, draw)
if (!is.null(draw)) {
iteration <- draw
Expand All @@ -91,16 +122,33 @@ subset_draws.draws_array <- function(x, variable = NULL, iteration = NULL,
#' @export
subset_draws.draws_df <- function(x, variable = NULL, iteration = NULL,
chain = NULL, draw = NULL, regex = FALSE,
unique = TRUE, ...) {
unique = TRUE, exclude = FALSE, ...) {
if (all_null(variable, iteration, chain, draw)) {
return(x)
}

x <- repair_draws(x)
unique <- as_one_logical(unique)
variable <- check_existing_variables(variable, x, regex = regex)
iteration <- check_iteration_ids(iteration, x, unique = unique)
chain <- check_chain_ids(chain, x, unique = unique)
draw <- check_draw_ids(draw, x, unique = unique)

if (exclude) {
if (!is.null(variable)) {
variable <- setdiff(variables(x), variable)
}
if (!is.null(iteration)) {
iteration <- setdiff(iteration_ids(x), iteration)
}
if (!is.null(chain)) {
chain <- setdiff(chain_ids(x), chain)
}
if (!is.null(draw)) {
draw <- setdiff(draw_ids(x), draw)
}
}

x <- prepare_subsetting(x, iteration, chain, draw)
x <- .subset_draws(
x, iteration, chain, draw, variable,
Expand All @@ -113,15 +161,32 @@ subset_draws.draws_df <- function(x, variable = NULL, iteration = NULL,
#' @export
subset_draws.draws_list <- function(x, variable = NULL, iteration = NULL,
chain = NULL, draw = NULL, regex = FALSE,
unique = TRUE, ...) {
unique = TRUE, exclude = FALSE, ...) {
if (all_null(variable, iteration, chain, draw)) {
return(x)
}

x <- repair_draws(x)
variable <- check_existing_variables(variable, x, regex = regex)
iteration <- check_iteration_ids(iteration, x, unique = unique)
chain <- check_chain_ids(chain, x, unique = unique)
draw <- check_draw_ids(draw, x, unique = unique)

if (exclude) {
if (!is.null(variable)) {
variable <- setdiff(variables(x), variable)
}
if (!is.null(iteration)) {
iteration <- setdiff(iteration_ids(x), iteration)
}
if (!is.null(chain)) {
chain <- setdiff(chain_ids(x), chain)
}
if (!is.null(draw)) {
draw <- setdiff(draw_ids(x), draw)
}
}

x <- prepare_subsetting(x, iteration, chain, draw)
if (!is.null(draw)) {
iteration <- draw
Expand All @@ -137,15 +202,32 @@ subset_draws.draws_list <- function(x, variable = NULL, iteration = NULL,
#' @export
subset_draws.draws_rvars <- function(x, variable = NULL, iteration = NULL,
chain = NULL, draw = NULL, regex = FALSE,
unique = TRUE, ...) {
unique = TRUE, exclude = FALSE, ...) {
if (all_null(variable, iteration, chain, draw)) {
return(x)
}

x <- repair_draws(x)
variable <- check_existing_variables(variable, x, regex = regex)
iteration <- check_iteration_ids(iteration, x, unique = unique)
chain <- check_chain_ids(chain, x, unique = unique)
draw <- check_draw_ids(draw, x, unique = unique)

if (exclude) {
if (!is.null(variable)) {
variable <- setdiff(variables(x), variable)
}
if (!is.null(iteration)) {
iteration <- setdiff(iteration_ids(x), iteration)
}
if (!is.null(chain)) {
chain <- setdiff(chain_ids(x), chain)
}
if (!is.null(draw)) {
draw <- setdiff(draw_ids(x), draw)
}
}

x <- prepare_subsetting(x, iteration, chain, draw)
if (!is.null(draw)) {
iteration <- draw
Expand Down
34 changes: 23 additions & 11 deletions man/subset_draws.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 43e6b4b

Please sign in to comment.