Skip to content

Commit

Permalink
check tidyselect output in fit.tailor()
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch committed Jun 5, 2024
1 parent dde5f44 commit 1d6cd1e
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 6 deletions.
12 changes: 6 additions & 6 deletions R/tailor.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,13 @@ fit.tailor <- function(object, .data, outcome, estimate, probabilities = c(),

columns <- list()
columns$outcome <- names(tidyselect::eval_select(enquo(outcome), .data))
check_selection(enquo(outcome), columns$outcome, "outcome")
columns$estimate <- names(tidyselect::eval_select(enquo(estimate), .data))

probabilities <- tidyselect::eval_select(enquo(probabilities), .data)
if (length(probabilities) > 0) {
columns$probabilities <- names(probabilities)
} else {
columns$probabilities <- character(0)
check_selection(enquo(estimate), columns$estimate, "estimate")
columns$probabilities <- names(tidyselect::eval_select(enquo(probabilities), .data))
if (any(c("probability", "everything") %in%
purrr::map_chr(object$adjustments, purrr::pluck, "inputs"))) {
check_selection(enquo(probabilities), columns$probabilities, "probabilities")
}

time <- tidyselect::eval_select(enquo(time), .data)
Expand Down
12 changes: 12 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,15 @@ check_method <- function(method,
method
}

# Abort with an informative error when a tidyselect specification matched
# no columns.
#
# `selector` is the quosure (or expression) the user supplied; it is only
#   used, via `as_label()`, to echo the selection back in the error message.
# `result` is the vector of selected column names; a length of zero is what
#   triggers the error.
# `arg` is the name of the user-facing argument, used in the error message.
# `call` is the environment the error is reported from. It defaults to the
#   caller of `check_selection()`, so the condition is attributed to the
#   user-facing function rather than to this helper.
#
# Returns `NULL` invisibly when `result` is non-empty.
check_selection <- function(selector, result, arg, call = caller_env()) {
  if (length(result) == 0) {
    cli_abort(
      c(
        "!" = "{.arg {arg}} must select at least one column.",
        "x" = "Selector {.code {as_label(selector)}} did not match any columns \\
               in {.arg .data}."
      ),
      # Forward the `call` argument instead of recomputing `caller_env()`
      # here, so that a caller-supplied `call` is actually respected.
      call = call
    )
  }
}
55 changes: 55 additions & 0 deletions tests/testthat/_snaps/tailor.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,58 @@
* Adjust probability threshold to 0.2.
* Add equivocal zone of size 0.1.

# error informatively with empty tidyselections

Code
tailor_fit <- tailor() %>% adjust_probability_threshold(0.5) %>% fit(
two_class_example, outcome = "truth_WRONG", estimate = "predicted",
probabilities = tidyselect::contains("Class"))
Condition
Error in `fit()`:
! Can't select columns that don't exist.
x Column `truth_WRONG` doesn't exist.

---

Code
tailor_fit <- tailor() %>% adjust_probability_threshold(0.5) %>% fit(
two_class_example, outcome = contains("truth_WRONG"), estimate = "predicted",
probabilities = tidyselect::contains("Class"))
Condition
Error in `fit()`:
! `outcome` must select at least one column.
x Selector `contains("truth_WRONG")` did not match any columns in `.data`.

---

Code
tailor_fit <- tailor() %>% adjust_probability_threshold(0.5) %>% fit(
two_class_example, outcome = "truth", estimate = "predicted_WRONG",
probabilities = tidyselect::contains("Class"))
Condition
Error in `fit()`:
! Can't select columns that don't exist.
x Column `predicted_WRONG` doesn't exist.

---

Code
tailor_fit <- tailor() %>% adjust_probability_threshold(0.5) %>% fit(
two_class_example, outcome = "truth", estimate = contains("predicted_WRONG"),
probabilities = tidyselect::contains("Class"))
Condition
Error in `fit()`:
! `estimate` must select at least one column.
x Selector `contains("predicted_WRONG")` did not match any columns in `.data`.

---

Code
tailor_fit <- tailor() %>% adjust_probability_threshold(0.5) %>% fit(
two_class_example, outcome = contains("truth"), estimate = "predicted",
probabilities = tidyselect::contains("Class_WRONG"))
Condition
Error in `fit()`:
! `probabilities` must select at least one column.
x Selector `tidyselect::contains("Class_WRONG")` did not match any columns in `.data`.

97 changes: 97 additions & 0 deletions tests/testthat/test-tailor.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,100 @@ test_that("tailor printing", {
adjust_equivocal_zone()
)
})

test_that("error informatively with empty tidyselections", {
# Exercises `fit()` on a tailor with column selections that match nothing,
# for each of `outcome`, `estimate`, and `probabilities`. The error cases
# are recorded with `expect_snapshot()`, so the messages are pinned in the
# snapshot file. The data comes from the optional modeldata package; the
# test is skipped when it is not installed.
skip_if_not_installed("modeldata")
data("two_class_example", package = "modeldata")

# Baseline: every selection matches at least one column, so fitting
# should signal no condition at all.
expect_no_condition(
tailor_fit <- tailor() %>%
adjust_probability_threshold(.5) %>%
fit(
two_class_example,
outcome = "truth",
estimate = "predicted",
probabilities = tidyselect::contains("Class")
)
)

# outcome doesn't exist, is given as a string naming a missing column
# (the recorded snapshot shows a "Can't select columns that don't exist"
# error for this case)
expect_snapshot(
error = TRUE,
tailor_fit <- tailor() %>%
adjust_probability_threshold(.5) %>%
fit(
two_class_example,
outcome = "truth_WRONG",
estimate = "predicted",
probabilities = tidyselect::contains("Class")
)
)

# outcome doesn't exist, is a selection helper that matches no columns
# (the recorded snapshot shows the "must select at least one column" error)
expect_snapshot(
error = TRUE,
tailor_fit <- tailor() %>%
adjust_probability_threshold(.5) %>%
fit(
two_class_example,
outcome = contains("truth_WRONG"),
estimate = "predicted",
probabilities = tidyselect::contains("Class")
)
)

# estimate doesn't exist, is given as a string naming a missing column
expect_snapshot(
error = TRUE,
tailor_fit <- tailor() %>%
adjust_probability_threshold(.5) %>%
fit(
two_class_example,
outcome = "truth",
estimate = "predicted_WRONG",
probabilities = tidyselect::contains("Class")
)
)

# estimate doesn't exist, is a selection helper that matches no columns
expect_snapshot(
error = TRUE,
tailor_fit <- tailor() %>%
adjust_probability_threshold(.5) %>%
fit(
two_class_example,
outcome = "truth",
estimate = contains("predicted_WRONG"),
probabilities = tidyselect::contains("Class")
)
)

# probabilities don't exist, given as a selection helper, and the tailor
# contains a probability threshold adjustment, so the empty selection is
# expected to error
expect_snapshot(
error = TRUE,
tailor_fit <- tailor() %>%
adjust_probability_threshold(.5) %>%
fit(
two_class_example,
outcome = contains("truth"),
estimate = "predicted",
probabilities = tidyselect::contains("Class_WRONG")
)
)

# probabilities don't exist, given as a selection helper, but aren't needed
# by the only adjustment in this tailor (a numeric range adjustment).
# (Asserting here that we ought not to error on a bad selection
# if it would not be used anyway.)
# TODO: need to overwrite the column name for now, see #22.
two_class_example$.pred <- two_class_example$Class2
expect_no_condition(
tailor_fit <- tailor() %>%
adjust_numeric_range(.5) %>%
fit(
two_class_example,
outcome = "Class1",
estimate = ".pred",
probabilities = tidyselect::contains("Class_WRONG")
)
)
})

0 comments on commit 1d6cd1e

Please sign in to comment.