Skip to content

Commit

Permalink
move excluding to check_* functions
Browse files Browse the repository at this point in the history
  • Loading branch information
n-kall committed Jan 12, 2024
1 parent f46f75b commit eab298d
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 99 deletions.
36 changes: 32 additions & 4 deletions R/draws-index.R
Original file line number Diff line number Diff line change
Expand Up @@ -478,15 +478,18 @@ ndraws.rvar <- function(x) {
# @param regex should 'variables' be treated as regular expressions?
# @param scalar_only should only scalar variables be matched?
check_existing_variables <- function(variables, x, regex = FALSE,
scalar_only = FALSE) {
scalar_only = FALSE, exclude = FALSE) {
check_draws_object(x)
if (is.null(variables)) {
return(NULL)
}

regex <- as_one_logical(regex)
scalar_only <- as_one_logical(scalar_only)
exclude <- as_one_logical(exclude)
variables <- unique(as.character(variables))
all_variables <- variables(x, reserved = TRUE)

if (regex) {
tmp <- named_list(variables)
for (i in seq_along(variables)) {
Expand Down Expand Up @@ -529,6 +532,12 @@ check_existing_variables <- function(variables, x, regex = FALSE,
stop_no_call("The following variables are missing in the draws object: ",
comma(missing_variables))
}

# handle excluding variables for subset_draws
if (exclude) {
variables <- setdiff(all_variables, variables)
}

invisible(variables)
}

Expand Down Expand Up @@ -564,12 +573,13 @@ check_reserved_variables <- function(variables) {

# check validity of iteration indices
# @param unique should the returned IDs be unique?
check_iteration_ids <- function(iteration_ids, x, unique = TRUE) {
check_iteration_ids <- function(iteration_ids, x, unique = TRUE, exclude = FALSE) {
check_draws_object(x)
if (is.null(iteration_ids)) {
return(NULL)
}
unique <- as_one_logical(unique)
exclude <- as_one_logical(exclude)
iteration_ids <- as.integer(iteration_ids)
if (unique) {
iteration_ids <- unique(iteration_ids)
Expand All @@ -584,17 +594,24 @@ check_iteration_ids <- function(iteration_ids, x, unique = TRUE) {
stop_no_call("Tried to subset iterations up to '", max_iteration, "' ",
"but the object only has '", niterations, "' iterations.")
}

# handle exclude iterations in subset_draws
if (exclude) {
iteration_ids <- setdiff(iteration_ids(x), iteration_ids)
}

invisible(iteration_ids)
}

# check validity of chain indices
# @param unique should the returned IDs be unique?
check_chain_ids <- function(chain_ids, x, unique = TRUE) {
check_chain_ids <- function(chain_ids, x, unique = TRUE, exclude = FALSE) {
check_draws_object(x)
if (is.null(chain_ids)) {
return(NULL)
}
unique <- as_one_logical(unique)
exclude <- as_one_logical(exclude)
chain_ids <- as.integer(chain_ids)
if (unique) {
chain_ids <- unique(chain_ids)
Expand All @@ -609,17 +626,23 @@ check_chain_ids <- function(chain_ids, x, unique = TRUE) {
stop_no_call("Tried to subset chains up to '", max_chain, "' ",
"but the object only has '", nchains, "' chains.")
}

if (exclude) {
chain_ids <- setdiff(chain_ids(x), chain_ids)
}

invisible(chain_ids)
}

# check validity of draw indices
# @param unique should the returned IDs be unique?
check_draw_ids <- function(draw_ids, x, unique = TRUE) {
check_draw_ids <- function(draw_ids, x, unique = TRUE, exclude = FALSE) {
check_draws_object(x)
if (is.null(draw_ids)) {
return(NULL)
}
unique <- as_one_logical(unique)
exclude <- as_one_logical(exclude)
draw_ids <- as.integer(draw_ids)
if (unique) {
draw_ids <- unique(draw_ids)
Expand All @@ -634,5 +657,10 @@ check_draw_ids <- function(draw_ids, x, unique = TRUE) {
stop_no_call("Tried to subset draws up to '", max_draw, "' ",
"but the object only has '", ndraws, "' draws.")
}

if (exclude) {
draw_ids <- setdif(draw_ids(x), draw_ids)
}

invisible(draw_ids)
}
115 changes: 20 additions & 95 deletions R/subset_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,25 +57,10 @@ subset_draws.draws_matrix <- function(x, variable = NULL, iteration = NULL,
return(x)
}
x <- repair_draws(x)
variable <- check_existing_variables(variable, x, regex = regex)
iteration <- check_iteration_ids(iteration, x, unique = unique)
chain <- check_chain_ids(chain, x, unique = unique)
draw <- check_draw_ids(draw, x, unique = unique)

if (exclude) {
if (!is.null(variable)) {
variable <- setdiff(variables(x, reserved = TRUE), variable)
}
if (!is.null(iteration)) {
iteration <- setdiff(iteration_ids(x), iteration)
}
if (!is.null(chain)) {
chain <- setdiff(chain_ids(x), chain)
}
if (!is.null(draw)) {
draw <- setdiff(draw_ids(x), draw)
}
}
variable <- check_existing_variables(variable, x, regex = regex, exclude = exclude)
iteration <- check_iteration_ids(iteration, x, unique = unique, exclude = exclude)
chain <- check_chain_ids(chain, x, unique = unique, exclude = exclude)
draw <- check_draw_ids(draw, x, unique = unique, exclude = exclude)

x <- prepare_subsetting(x, iteration, chain, draw)
x <- .subset_draws(x, iteration, chain, draw, variable, reserved = TRUE)
Expand All @@ -95,25 +80,10 @@ subset_draws.draws_array <- function(x, variable = NULL, iteration = NULL,
}

x <- repair_draws(x)
variable <- check_existing_variables(variable, x, regex = regex)
iteration <- check_iteration_ids(iteration, x, unique = unique)
chain <- check_chain_ids(chain, x, unique = unique)
draw <- check_draw_ids(draw, x, unique = unique)

if (exclude) {
if (!is.null(variable)) {
variable <- setdiff(variables(x, reserved = TRUE), variable)
}
if (!is.null(iteration)) {
iteration <- setdiff(iteration_ids(x), iteration)
}
if (!is.null(chain)) {
chain <- setdiff(chain_ids(x), chain)
}
if (!is.null(draw)) {
draw <- setdiff(draw_ids(x), draw)
}
}
variable <- check_existing_variables(variable, x, regex = regex, exclude = exclude)
iteration <- check_iteration_ids(iteration, x, unique = unique, exclude = exclude)
chain <- check_chain_ids(chain, x, unique = unique, exclude = exclude)
draw <- check_draw_ids(draw, x, unique = unique, exclude = exclude)

x <- prepare_subsetting(x, iteration, chain, draw)
if (!is.null(draw)) {
Expand All @@ -137,25 +107,10 @@ subset_draws.draws_df <- function(x, variable = NULL, iteration = NULL,

x <- repair_draws(x)
unique <- as_one_logical(unique)
variable <- check_existing_variables(variable, x, regex = regex)
iteration <- check_iteration_ids(iteration, x, unique = unique)
chain <- check_chain_ids(chain, x, unique = unique)
draw <- check_draw_ids(draw, x, unique = unique)

if (exclude) {
if (!is.null(variable)) {
variable <- setdiff(variables(x, reserved = TRUE), variable)
}
if (!is.null(iteration)) {
iteration <- setdiff(iteration_ids(x), iteration)
}
if (!is.null(chain)) {
chain <- setdiff(chain_ids(x), chain)
}
if (!is.null(draw)) {
draw <- setdiff(draw_ids(x), draw)
}
}
variable <- check_existing_variables(variable, x, regex = regex, exclude= exclude)
iteration <- check_iteration_ids(iteration, x, unique = unique, exclude= exclude)
chain <- check_chain_ids(chain, x, unique = unique, exclude= exclude)
draw <- check_draw_ids(draw, x, unique = unique, exclude= exclude)

x <- prepare_subsetting(x, iteration, chain, draw)
x <- .subset_draws(
Expand All @@ -175,25 +130,10 @@ subset_draws.draws_list <- function(x, variable = NULL, iteration = NULL,
}

x <- repair_draws(x)
variable <- check_existing_variables(variable, x, regex = regex)
iteration <- check_iteration_ids(iteration, x, unique = unique)
chain <- check_chain_ids(chain, x, unique = unique)
draw <- check_draw_ids(draw, x, unique = unique)

if (exclude) {
if (!is.null(variable)) {
variable <- setdiff(variables(x, reserved = TRUE), variable)
}
if (!is.null(iteration)) {
iteration <- setdiff(iteration_ids(x), iteration)
}
if (!is.null(chain)) {
chain <- setdiff(chain_ids(x), chain)
}
if (!is.null(draw)) {
draw <- setdiff(draw_ids(x), draw)
}
}
variable <- check_existing_variables(variable, x, regex = regex, exclude = exclude)
iteration <- check_iteration_ids(iteration, x, unique = unique, exclude = exclude)
chain <- check_chain_ids(chain, x, unique = unique, exclude = exclude)
draw <- check_draw_ids(draw, x, unique = unique, exclude = exclude)

x <- prepare_subsetting(x, iteration, chain, draw)
if (!is.null(draw)) {
Expand All @@ -216,25 +156,10 @@ subset_draws.draws_rvars <- function(x, variable = NULL, iteration = NULL,
}

x <- repair_draws(x)
variable <- check_existing_variables(variable, x, regex = regex)
iteration <- check_iteration_ids(iteration, x, unique = unique)
chain <- check_chain_ids(chain, x, unique = unique)
draw <- check_draw_ids(draw, x, unique = unique)

if (exclude) {
if (!is.null(variable)) {
variable <- setdiff(variables(x, reserved = TRUE), variable)
}
if (!is.null(iteration)) {
iteration <- setdiff(iteration_ids(x), iteration)
}
if (!is.null(chain)) {
chain <- setdiff(chain_ids(x), chain)
}
if (!is.null(draw)) {
draw <- setdiff(draw_ids(x), draw)
}
}
variable <- check_existing_variables(variable, x, regex = regex, exclude = exclude)
iteration <- check_iteration_ids(iteration, x, unique = unique, exclude= exclude)
chain <- check_chain_ids(chain, x, unique = unique, exclude= exclude)
draw <- check_draw_ids(draw, x, unique = unique, exclude= exclude)

x <- prepare_subsetting(x, iteration, chain, draw)
if (!is.null(draw)) {
Expand Down

0 comments on commit eab298d

Please sign in to comment.