From eab298de2f417a6560c049adc65ddbe1361348ca Mon Sep 17 00:00:00 2001 From: n-kall Date: Fri, 12 Jan 2024 10:51:53 +0200 Subject: [PATCH] move excluding to check_* functions --- R/draws-index.R | 36 +++++++++++++-- R/subset_draws.R | 115 +++++++++-------------------------------------- 2 files changed, 52 insertions(+), 99 deletions(-) diff --git a/R/draws-index.R b/R/draws-index.R index 8cdfa990..748ba983 100644 --- a/R/draws-index.R +++ b/R/draws-index.R @@ -478,15 +478,18 @@ ndraws.rvar <- function(x) { # @param regex should 'variables' be treated as regular expressions? # @param scalar_only should only scalar variables be matched? check_existing_variables <- function(variables, x, regex = FALSE, - scalar_only = FALSE) { + scalar_only = FALSE, exclude = FALSE) { check_draws_object(x) if (is.null(variables)) { return(NULL) } + regex <- as_one_logical(regex) scalar_only <- as_one_logical(scalar_only) + exclude <- as_one_logical(exclude) variables <- unique(as.character(variables)) all_variables <- variables(x, reserved = TRUE) + if (regex) { tmp <- named_list(variables) for (i in seq_along(variables)) { @@ -529,6 +532,12 @@ check_existing_variables <- function(variables, x, regex = FALSE, stop_no_call("The following variables are missing in the draws object: ", comma(missing_variables)) } + + # handle excluding variables for subset_draws + if (exclude) { + variables <- setdiff(all_variables, variables) + } + invisible(variables) } @@ -564,12 +573,13 @@ check_reserved_variables <- function(variables) { # check validity of iteration indices # @param unique should the returned IDs be unique? -check_iteration_ids <- function(iteration_ids, x, unique = TRUE) { +check_iteration_ids <- function(iteration_ids, x, unique = TRUE, exclude = FALSE) { check_draws_object(x) if (is.null(iteration_ids)) { return(NULL) } unique <- as_one_logical(unique) + exclude <- as_one_logical(exclude) iteration_ids <- as.integer(iteration_ids) if (unique) { iteration_ids <- unique(iteration_ids) @@ -584,17 +594,24 @@ check_iteration_ids <- function(iteration_ids, x, unique = TRUE) { stop_no_call("Tried to subset iterations up to '", max_iteration, "' ", "but the object only has '", niterations, "' iterations.") } + + # handle exclude iterations in subset_draws + if (exclude) { + iteration_ids <- setdiff(iteration_ids(x), iteration_ids) + } + invisible(iteration_ids) } # check validity of chain indices # @param unique should the returned IDs be unique? -check_chain_ids <- function(chain_ids, x, unique = TRUE) { +check_chain_ids <- function(chain_ids, x, unique = TRUE, exclude = FALSE) { check_draws_object(x) if (is.null(chain_ids)) { return(NULL) } unique <- as_one_logical(unique) + exclude <- as_one_logical(exclude) chain_ids <- as.integer(chain_ids) if (unique) { chain_ids <- unique(chain_ids) @@ -609,17 +626,23 @@ check_chain_ids <- function(chain_ids, x, unique = TRUE) { stop_no_call("Tried to subset chains up to '", max_chain, "' ", "but the object only has '", nchains, "' chains.") } + + if (exclude) { + chain_ids <- setdiff(chain_ids(x), chain_ids) + } + invisible(chain_ids) } # check validity of draw indices # @param unique should the returned IDs be unique? -check_draw_ids <- function(draw_ids, x, unique = TRUE) { +check_draw_ids <- function(draw_ids, x, unique = TRUE, exclude = FALSE) { check_draws_object(x) if (is.null(draw_ids)) { return(NULL) } unique <- as_one_logical(unique) + exclude <- as_one_logical(exclude) draw_ids <- as.integer(draw_ids) if (unique) { draw_ids <- unique(draw_ids) @@ -634,5 +657,10 @@ check_draw_ids <- function(draw_ids, x, unique = TRUE) { stop_no_call("Tried to subset draws up to '", max_draw, "' ", "but the object only has '", ndraws, "' draws.") } + + if (exclude) { + draw_ids <- setdif(draw_ids(x), draw_ids) + } + invisible(draw_ids) } diff --git a/R/subset_draws.R b/R/subset_draws.R index 87486650..503b673b 100644 --- a/R/subset_draws.R +++ b/R/subset_draws.R @@ -57,25 +57,10 @@ subset_draws.draws_matrix <- function(x, variable = NULL, iteration = NULL, return(x) } x <- repair_draws(x) - variable <- check_existing_variables(variable, x, regex = regex) - iteration <- check_iteration_ids(iteration, x, unique = unique) - chain <- check_chain_ids(chain, x, unique = unique) - draw <- check_draw_ids(draw, x, unique = unique) - - if (exclude) { - if (!is.null(variable)) { - variable <- setdiff(variables(x, reserved = TRUE), variable) - } - if (!is.null(iteration)) { - iteration <- setdiff(iteration_ids(x), iteration) - } - if (!is.null(chain)) { - chain <- setdiff(chain_ids(x), chain) - } - if (!is.null(draw)) { - draw <- setdiff(draw_ids(x), draw) - } - } + variable <- check_existing_variables(variable, x, regex = regex, exclude = exclude) + iteration <- check_iteration_ids(iteration, x, unique = unique, exclude = exclude) + chain <- check_chain_ids(chain, x, unique = unique, exclude = exclude) + draw <- check_draw_ids(draw, x, unique = unique, exclude = exclude) x <- prepare_subsetting(x, iteration, chain, draw) x <- .subset_draws(x, iteration, chain, draw, variable, reserved = TRUE) @@ -95,25 +80,10 @@ subset_draws.draws_array <- function(x, variable = NULL, iteration = NULL, } x <- repair_draws(x) - variable <- check_existing_variables(variable, x, regex = regex) - iteration <- check_iteration_ids(iteration, x, unique = unique) - chain <- check_chain_ids(chain, x, unique = unique) - draw <- check_draw_ids(draw, x, unique = unique) - - if (exclude) { - if (!is.null(variable)) { - variable <- setdiff(variables(x, reserved = TRUE), variable) - } - if (!is.null(iteration)) { - iteration <- setdiff(iteration_ids(x), iteration) - } - if (!is.null(chain)) { - chain <- setdiff(chain_ids(x), chain) - } - if (!is.null(draw)) { - draw <- setdiff(draw_ids(x), draw) - } - } + variable <- check_existing_variables(variable, x, regex = regex, exclude = exclude) + iteration <- check_iteration_ids(iteration, x, unique = unique, exclude = exclude) + chain <- check_chain_ids(chain, x, unique = unique, exclude = exclude) + draw <- check_draw_ids(draw, x, unique = unique, exclude = exclude) x <- prepare_subsetting(x, iteration, chain, draw) if (!is.null(draw)) { @@ -137,25 +107,10 @@ subset_draws.draws_df <- function(x, variable = NULL, iteration = NULL, x <- repair_draws(x) unique <- as_one_logical(unique) - variable <- check_existing_variables(variable, x, regex = regex) - iteration <- check_iteration_ids(iteration, x, unique = unique) - chain <- check_chain_ids(chain, x, unique = unique) - draw <- check_draw_ids(draw, x, unique = unique) - - if (exclude) { - if (!is.null(variable)) { - variable <- setdiff(variables(x, reserved = TRUE), variable) - } - if (!is.null(iteration)) { - iteration <- setdiff(iteration_ids(x), iteration) - } - if (!is.null(chain)) { - chain <- setdiff(chain_ids(x), chain) - } - if (!is.null(draw)) { - draw <- setdiff(draw_ids(x), draw) - } - } + variable <- check_existing_variables(variable, x, regex = regex, exclude= exclude) + iteration <- check_iteration_ids(iteration, x, unique = unique, exclude= exclude) + chain <- check_chain_ids(chain, x, unique = unique, exclude= exclude) + draw <- check_draw_ids(draw, x, unique = unique, exclude= exclude) x <- prepare_subsetting(x, iteration, chain, draw) x <- .subset_draws( @@ -175,25 +130,10 @@ subset_draws.draws_list <- function(x, variable = NULL, iteration = NULL, } x <- repair_draws(x) - variable <- check_existing_variables(variable, x, regex = regex) - iteration <- check_iteration_ids(iteration, x, unique = unique) - chain <- check_chain_ids(chain, x, unique = unique) - draw <- check_draw_ids(draw, x, unique = unique) - - if (exclude) { - if (!is.null(variable)) { - variable <- setdiff(variables(x, reserved = TRUE), variable) - } - if (!is.null(iteration)) { - iteration <- setdiff(iteration_ids(x), iteration) - } - if (!is.null(chain)) { - chain <- setdiff(chain_ids(x), chain) - } - if (!is.null(draw)) { - draw <- setdiff(draw_ids(x), draw) - } - } + variable <- check_existing_variables(variable, x, regex = regex, exclude = exclude) + iteration <- check_iteration_ids(iteration, x, unique = unique, exclude = exclude) + chain <- check_chain_ids(chain, x, unique = unique, exclude = exclude) + draw <- check_draw_ids(draw, x, unique = unique, exclude = exclude) x <- prepare_subsetting(x, iteration, chain, draw) if (!is.null(draw)) { @@ -216,25 +156,10 @@ subset_draws.draws_rvars <- function(x, variable = NULL, iteration = NULL, } x <- repair_draws(x) - variable <- check_existing_variables(variable, x, regex = regex) - iteration <- check_iteration_ids(iteration, x, unique = unique) - chain <- check_chain_ids(chain, x, unique = unique) - draw <- check_draw_ids(draw, x, unique = unique) - - if (exclude) { - if (!is.null(variable)) { - variable <- setdiff(variables(x, reserved = TRUE), variable) - } - if (!is.null(iteration)) { - iteration <- setdiff(iteration_ids(x), iteration) - } - if (!is.null(chain)) { - chain <- setdiff(chain_ids(x), chain) - } - if (!is.null(draw)) { - draw <- setdiff(draw_ids(x), draw) - } - } + variable <- check_existing_variables(variable, x, regex = regex, exclude = exclude) + iteration <- check_iteration_ids(iteration, x, unique = unique, exclude= exclude) + chain <- check_chain_ids(chain, x, unique = unique, exclude= exclude) + draw <- check_draw_ids(draw, x, unique = unique, exclude= exclude) x <- prepare_subsetting(x, iteration, chain, draw) if (!is.null(draw)) {