diff --git a/R/tailor.R b/R/tailor.R
index 7bf7fbb..37e5d18 100644
--- a/R/tailor.R
+++ b/R/tailor.R
@@ -86,13 +86,13 @@ fit.tailor <- function(object, .data, outcome, estimate, probabilities = c(),
   columns <- list()
 
   columns$outcome <- names(tidyselect::eval_select(enquo(outcome), .data))
+  check_selection(enquo(outcome), columns$outcome, "outcome")
   columns$estimate <- names(tidyselect::eval_select(enquo(estimate), .data))
-
-  probabilities <- tidyselect::eval_select(enquo(probabilities), .data)
-  if (length(probabilities) > 0) {
-    columns$probabilities <- names(probabilities)
-  } else {
-    columns$probabilities <- character(0)
+  check_selection(enquo(estimate), columns$estimate, "estimate")
+  columns$probabilities <- names(tidyselect::eval_select(enquo(probabilities), .data))
+  if (any(c("probability", "everything") %in%
+          purrr::map_chr(object$adjustments, purrr::pluck, "inputs"))) {
+    check_selection(enquo(probabilities), columns$probabilities, "probabilities")
   }
 
   time <- tidyselect::eval_select(enquo(time), .data)
diff --git a/R/utils.R b/R/utils.R
index 211c4ad..ea6224b 100644
--- a/R/utils.R
+++ b/R/utils.R
@@ -176,3 +176,25 @@ check_method <- function(method,
   method
 }
 
+# Abort with an informative error when a tidyselect selection is empty.
+#
+# `selector` is the quosure that was evaluated (labelled with `as_label()`
+# in the message), `result` is the vector of selected column names, and
+# `arg` is the argument name shown to the user. `call` is forwarded to
+# `cli_abort()` and defaults to the caller of this helper, so the error is
+# reported against the calling function rather than `check_selection()`.
+# Does nothing when `result` is non-empty.
+check_selection <- function(selector, result, arg, call = caller_env()) {
+  if (length(result) == 0) {
+    cli_abort(
+      c(
+        "!" = "{.arg {arg}} must select at least one column.",
+        "x" = "Selector {.code {as_label(selector)}} did not match any columns \\
+               in {.arg .data}."
+      ),
+      # Honour the `call` argument instead of recomputing `caller_env()`
+      # here, which silently ignored any value supplied by the caller.
+      call = call
+    )
+  }
+}
diff --git a/tests/testthat/_snaps/tailor.md b/tests/testthat/_snaps/tailor.md
index f0d8ae0..a625ae8 100644
--- a/tests/testthat/_snaps/tailor.md
+++ b/tests/testthat/_snaps/tailor.md
@@ -40,3 +40,58 @@
     * Adjust probability threshold to 0.2.
     * Add equivocal zone of size 0.1.
 
+# error informatively with empty tidyselections + + Code + tailor_fit <- tailor() %>% adjust_probability_threshold(0.5) %>% fit( + two_class_example, outcome = "truth_WRONG", estimate = "predicted", + probabilities = tidyselect::contains("Class")) + Condition + Error in `fit()`: + ! Can't select columns that don't exist. + x Column `truth_WRONG` doesn't exist. + +--- + + Code + tailor_fit <- tailor() %>% adjust_probability_threshold(0.5) %>% fit( + two_class_example, outcome = contains("truth_WRONG"), estimate = "predicted", + probabilities = tidyselect::contains("Class")) + Condition + Error in `fit()`: + ! `outcome` must select at least one column. + x Selector `contains("truth_WRONG")` did not match any columns in `.data`. + +--- + + Code + tailor_fit <- tailor() %>% adjust_probability_threshold(0.5) %>% fit( + two_class_example, outcome = "truth", estimate = "predicted_WRONG", + probabilities = tidyselect::contains("Class")) + Condition + Error in `fit()`: + ! Can't select columns that don't exist. + x Column `predicted_WRONG` doesn't exist. + +--- + + Code + tailor_fit <- tailor() %>% adjust_probability_threshold(0.5) %>% fit( + two_class_example, outcome = "truth", estimate = contains("predicted_WRONG"), + probabilities = tidyselect::contains("Class")) + Condition + Error in `fit()`: + ! `estimate` must select at least one column. + x Selector `contains("predicted_WRONG")` did not match any columns in `.data`. + +--- + + Code + tailor_fit <- tailor() %>% adjust_probability_threshold(0.5) %>% fit( + two_class_example, outcome = contains("truth"), estimate = "predicted", + probabilities = tidyselect::contains("Class_WRONG")) + Condition + Error in `fit()`: + ! `probabilities` must select at least one column. + x Selector `tidyselect::contains("Class_WRONG")` did not match any columns in `.data`. 
+
diff --git a/tests/testthat/test-tailor.R b/tests/testthat/test-tailor.R
index 1c38236..b191a87 100644
--- a/tests/testthat/test-tailor.R
+++ b/tests/testthat/test-tailor.R
@@ -11,3 +11,100 @@ test_that("tailor printing", {
       adjust_equivocal_zone()
   )
 })
+
+test_that("error informatively with empty tidyselections", {
+  skip_if_not_installed("modeldata")
+  data("two_class_example", package = "modeldata")
+
+  expect_no_condition(
+    tailor_fit <- tailor() %>%
+      adjust_probability_threshold(.5) %>%
+      fit(
+        two_class_example,
+        outcome = "truth",
+        estimate = "predicted",
+        probabilities = tidyselect::contains("Class")
+      )
+  )
+
+  # outcome is a string naming a missing column; tidyselect itself errors
+  expect_snapshot(
+    error = TRUE,
+    tailor_fit <- tailor() %>%
+      adjust_probability_threshold(.5) %>%
+      fit(
+        two_class_example,
+        outcome = "truth_WRONG",
+        estimate = "predicted",
+        probabilities = tidyselect::contains("Class")
+      )
+  )
+
+  # outcome is a helper matching no columns; check_selection() errors
+  expect_snapshot(
+    error = TRUE,
+    tailor_fit <- tailor() %>%
+      adjust_probability_threshold(.5) %>%
+      fit(
+        two_class_example,
+        outcome = contains("truth_WRONG"),
+        estimate = "predicted",
+        probabilities = tidyselect::contains("Class")
+      )
+  )
+
+  # estimate is a string naming a missing column; tidyselect itself errors
+  expect_snapshot(
+    error = TRUE,
+    tailor_fit <- tailor() %>%
+      adjust_probability_threshold(.5) %>%
+      fit(
+        two_class_example,
+        outcome = "truth",
+        estimate = "predicted_WRONG",
+        probabilities = tidyselect::contains("Class")
+      )
+  )
+
+  # estimate is a helper matching no columns; check_selection() errors
+  expect_snapshot(
+    error = TRUE,
+    tailor_fit <- tailor() %>%
+      adjust_probability_threshold(.5) %>%
+      fit(
+        two_class_example,
+        outcome = "truth",
+        estimate = contains("predicted_WRONG"),
+        probabilities = tidyselect::contains("Class")
+      )
+  )
+
+  # probabilities helper matches no columns; errors with a threshold adjustment
+  expect_snapshot(
+    error = TRUE,
+    tailor_fit <- tailor() %>%
+      adjust_probability_threshold(.5) %>%
+      fit(
+        two_class_example,
+        outcome = contains("truth"),
+        estimate = "predicted",
+        probabilities = tidyselect::contains("Class_WRONG")
+      )
+  )
+
+  # probabilities helper matches no columns, but no error is expected with
+  # adjust_numeric_range() (asserting that a bad selection should not error
+  # when it would not be used anyway).
+  # todo: need to overwrite column name for now, see #22.
+  two_class_example$.pred <- two_class_example$Class2
+  expect_no_condition(
+    tailor_fit <- tailor() %>%
+      adjust_numeric_range(.5) %>%
+      fit(
+        two_class_example,
+        outcome = "Class1",
+        estimate = ".pred",
+        probabilities = tidyselect::contains("Class_WRONG")
+      )
+  )
+})