Skip to content

Commit

Permalink
Merge pull request #336 from n-kall/exclude_draws
Browse files Browse the repository at this point in the history
Exclude draws
  • Loading branch information
paul-buerkner authored Jan 15, 2024
2 parents 98bfcbd + f919ba1 commit a63b894
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 50 deletions.
36 changes: 32 additions & 4 deletions R/draws-index.R
Original file line number Diff line number Diff line change
Expand Up @@ -478,15 +478,18 @@ ndraws.rvar <- function(x) {
# @param regex should 'variables' be treated as regular expressions?
# @param scalar_only should only scalar variables be matched?
check_existing_variables <- function(variables, x, regex = FALSE,
scalar_only = FALSE) {
scalar_only = FALSE, exclude = FALSE) {
check_draws_object(x)
if (is.null(variables)) {
return(NULL)
}

regex <- as_one_logical(regex)
scalar_only <- as_one_logical(scalar_only)
exclude <- as_one_logical(exclude)
variables <- unique(as.character(variables))
all_variables <- variables(x, reserved = TRUE)

if (regex) {
tmp <- named_list(variables)
for (i in seq_along(variables)) {
Expand Down Expand Up @@ -529,6 +532,12 @@ check_existing_variables <- function(variables, x, regex = FALSE,
stop_no_call("The following variables are missing in the draws object: ",
comma(missing_variables))
}

# handle excluding variables for subset_draws
if (exclude) {
variables <- setdiff(all_variables, variables)
}

invisible(variables)
}

Expand Down Expand Up @@ -564,12 +573,13 @@ check_reserved_variables <- function(variables) {

# check validity of iteration indices
# @param unique should the returned IDs be unique?
check_iteration_ids <- function(iteration_ids, x, unique = TRUE) {
check_iteration_ids <- function(iteration_ids, x, unique = TRUE, exclude = FALSE) {
check_draws_object(x)
if (is.null(iteration_ids)) {
return(NULL)
}
unique <- as_one_logical(unique)
exclude <- as_one_logical(exclude)
iteration_ids <- as.integer(iteration_ids)
if (unique) {
iteration_ids <- unique(iteration_ids)
Expand All @@ -584,17 +594,24 @@ check_iteration_ids <- function(iteration_ids, x, unique = TRUE) {
stop_no_call("Tried to subset iterations up to '", max_iteration, "' ",
"but the object only has '", niterations, "' iterations.")
}

# handle exclude iterations in subset_draws
if (exclude) {
iteration_ids <- setdiff(iteration_ids(x), iteration_ids)
}

invisible(iteration_ids)
}

# check validity of chain indices
# @param unique should the returned IDs be unique?
check_chain_ids <- function(chain_ids, x, unique = TRUE) {
check_chain_ids <- function(chain_ids, x, unique = TRUE, exclude = FALSE) {
check_draws_object(x)
if (is.null(chain_ids)) {
return(NULL)
}
unique <- as_one_logical(unique)
exclude <- as_one_logical(exclude)
chain_ids <- as.integer(chain_ids)
if (unique) {
chain_ids <- unique(chain_ids)
Expand All @@ -609,17 +626,23 @@ check_chain_ids <- function(chain_ids, x, unique = TRUE) {
stop_no_call("Tried to subset chains up to '", max_chain, "' ",
"but the object only has '", nchains, "' chains.")
}

if (exclude) {
chain_ids <- setdiff(chain_ids(x), chain_ids)
}

invisible(chain_ids)
}

# check validity of draw indices
# @param unique should the returned IDs be unique?
check_draw_ids <- function(draw_ids, x, unique = TRUE) {
check_draw_ids <- function(draw_ids, x, unique = TRUE, exclude = FALSE) {
check_draws_object(x)
if (is.null(draw_ids)) {
return(NULL)
}
unique <- as_one_logical(unique)
exclude <- as_one_logical(exclude)
draw_ids <- as.integer(draw_ids)
if (unique) {
draw_ids <- unique(draw_ids)
Expand All @@ -634,5 +657,10 @@ check_draw_ids <- function(draw_ids, x, unique = TRUE) {
stop_no_call("Tried to subset draws up to '", max_draw, "' ",
"but the object only has '", ndraws, "' draws.")
}

if (exclude) {
draw_ids <- setdiff(draw_ids(x), draw_ids)
}

invisible(draw_ids)
}
85 changes: 50 additions & 35 deletions R/subset_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,25 @@
#' Subset [`draws`] objects by variables, iterations, chains, and draws indices.
#'
#' @template args-methods-x
#' @param variable (character vector) The variables to select. All elements of
#' non-scalar variables can be selected at once.
#' @param variable (character vector) The variables to select. All
#' elements of non-scalar variables can be selected at once.
#' @param iteration (integer vector) The iteration indices to select.
#' @param chain (integer vector) The chain indices to select.
#' @param draw (integer vector) The draw indices to be select. Subsetting draw
#' indices will lead to an automatic merging of chains via [`merge_chains`].
#' @param draw (integer vector) The draw indices to be
#' select. Subsetting draw indices will lead to an automatic merging
#' of chains via [`merge_chains`].
#' @param regex (logical) Should `variable` should be treated as a
#' (vector of) regular expressions? Any variable in `x` matching at least one
#' of the regular expressions will be selected. Defaults to `FALSE`.
#' @param unique (logical) Should duplicated selection of chains, iterations, or
#' draws be allowed? If `TRUE` (the default) only unique chains, iterations,
#' and draws are selected regardless of how often they appear in the
#' respective selecting arguments.
#' (vector of) regular expressions? Any variable in `x` matching at
#' least one of the regular expressions will be selected. Defaults
#' to `FALSE`.
#' @param unique (logical) Should duplicated selection of chains,
#' iterations, or draws be allowed? If `TRUE` (the default) only
#' unique chains, iterations, and draws are selected regardless of
#' how often they appear in the respective selecting arguments.
#' @param exclude (logical) Should the selected subset be excluded?
#' If `FALSE` (the default) only the selected subset will be
#' returned. If `TRUE` everything but the selected subset will be
#' returned.
#'
#' @template args-methods-dots
#' @template return-draws
Expand Down Expand Up @@ -46,15 +52,16 @@ subset_draws <- function(x, ...) {
#' @export
subset_draws.draws_matrix <- function(x, variable = NULL, iteration = NULL,
chain = NULL, draw = NULL, regex = FALSE,
unique = TRUE, ...) {
unique = TRUE, exclude = FALSE, ...) {
if (all_null(variable, iteration, chain, draw)) {
return(x)
}
x <- repair_draws(x)
variable <- check_existing_variables(variable, x, regex = regex)
iteration <- check_iteration_ids(iteration, x, unique = unique)
chain <- check_chain_ids(chain, x, unique = unique)
draw <- check_draw_ids(draw, x, unique = unique)
variable <- check_existing_variables(variable, x, regex = regex, exclude = exclude)
iteration <- check_iteration_ids(iteration, x, unique = unique, exclude = exclude)
chain <- check_chain_ids(chain, x, unique = unique, exclude = exclude)
draw <- check_draw_ids(draw, x, unique = unique, exclude = exclude)

x <- prepare_subsetting(x, iteration, chain, draw)
x <- .subset_draws(x, iteration, chain, draw, variable, reserved = TRUE)
if (!is.null(chain) || !is.null(iteration)) {
Expand All @@ -67,15 +74,17 @@ subset_draws.draws_matrix <- function(x, variable = NULL, iteration = NULL,
#' @export
subset_draws.draws_array <- function(x, variable = NULL, iteration = NULL,
chain = NULL, draw = NULL, regex = FALSE,
unique = TRUE, ...) {
unique = TRUE, exclude = FALSE, ...) {
if (all_null(variable, iteration, chain, draw)) {
return(x)
}

x <- repair_draws(x)
variable <- check_existing_variables(variable, x, regex = regex)
iteration <- check_iteration_ids(iteration, x, unique = unique)
chain <- check_chain_ids(chain, x, unique = unique)
draw <- check_draw_ids(draw, x, unique = unique)
variable <- check_existing_variables(variable, x, regex = regex, exclude = exclude)
iteration <- check_iteration_ids(iteration, x, unique = unique, exclude = exclude)
chain <- check_chain_ids(chain, x, unique = unique, exclude = exclude)
draw <- check_draw_ids(draw, x, unique = unique, exclude = exclude)

x <- prepare_subsetting(x, iteration, chain, draw)
if (!is.null(draw)) {
iteration <- draw
Expand All @@ -91,16 +100,18 @@ subset_draws.draws_array <- function(x, variable = NULL, iteration = NULL,
#' @export
subset_draws.draws_df <- function(x, variable = NULL, iteration = NULL,
chain = NULL, draw = NULL, regex = FALSE,
unique = TRUE, ...) {
unique = TRUE, exclude = FALSE, ...) {
if (all_null(variable, iteration, chain, draw)) {
return(x)
}

x <- repair_draws(x)
unique <- as_one_logical(unique)
variable <- check_existing_variables(variable, x, regex = regex)
iteration <- check_iteration_ids(iteration, x, unique = unique)
chain <- check_chain_ids(chain, x, unique = unique)
draw <- check_draw_ids(draw, x, unique = unique)
variable <- check_existing_variables(variable, x, regex = regex, exclude= exclude)
iteration <- check_iteration_ids(iteration, x, unique = unique, exclude= exclude)
chain <- check_chain_ids(chain, x, unique = unique, exclude= exclude)
draw <- check_draw_ids(draw, x, unique = unique, exclude= exclude)

x <- prepare_subsetting(x, iteration, chain, draw)
x <- .subset_draws(
x, iteration, chain, draw, variable,
Expand All @@ -113,15 +124,17 @@ subset_draws.draws_df <- function(x, variable = NULL, iteration = NULL,
#' @export
subset_draws.draws_list <- function(x, variable = NULL, iteration = NULL,
chain = NULL, draw = NULL, regex = FALSE,
unique = TRUE, ...) {
unique = TRUE, exclude = FALSE, ...) {
if (all_null(variable, iteration, chain, draw)) {
return(x)
}

x <- repair_draws(x)
variable <- check_existing_variables(variable, x, regex = regex)
iteration <- check_iteration_ids(iteration, x, unique = unique)
chain <- check_chain_ids(chain, x, unique = unique)
draw <- check_draw_ids(draw, x, unique = unique)
variable <- check_existing_variables(variable, x, regex = regex, exclude = exclude)
iteration <- check_iteration_ids(iteration, x, unique = unique, exclude = exclude)
chain <- check_chain_ids(chain, x, unique = unique, exclude = exclude)
draw <- check_draw_ids(draw, x, unique = unique, exclude = exclude)

x <- prepare_subsetting(x, iteration, chain, draw)
if (!is.null(draw)) {
iteration <- draw
Expand All @@ -137,15 +150,17 @@ subset_draws.draws_list <- function(x, variable = NULL, iteration = NULL,
#' @export
subset_draws.draws_rvars <- function(x, variable = NULL, iteration = NULL,
chain = NULL, draw = NULL, regex = FALSE,
unique = TRUE, ...) {
unique = TRUE, exclude = FALSE, ...) {
if (all_null(variable, iteration, chain, draw)) {
return(x)
}

x <- repair_draws(x)
variable <- check_existing_variables(variable, x, regex = regex)
iteration <- check_iteration_ids(iteration, x, unique = unique)
chain <- check_chain_ids(chain, x, unique = unique)
draw <- check_draw_ids(draw, x, unique = unique)
variable <- check_existing_variables(variable, x, regex = regex, exclude = exclude)
iteration <- check_iteration_ids(iteration, x, unique = unique, exclude= exclude)
chain <- check_chain_ids(chain, x, unique = unique, exclude= exclude)
draw <- check_draw_ids(draw, x, unique = unique, exclude= exclude)

x <- prepare_subsetting(x, iteration, chain, draw)
if (!is.null(draw)) {
iteration <- draw
Expand Down
33 changes: 22 additions & 11 deletions man/subset_draws.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit a63b894

Please sign in to comment.