Skip to content

Commit

Permalink
Merge pull request #548 from tidymodels/checks-vfold
Browse files Browse the repository at this point in the history
Update input checks for `vfold_cv.R`
  • Loading branch information
hfrick authored Sep 23, 2024
2 parents b02f57d + a74d3a4 commit bf70b7b
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 41 deletions.
2 changes: 1 addition & 1 deletion R/clustering.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ clustering_cv <- function(data,
distance_function = "dist",
cluster_function = c("kmeans", "hclust"),
...) {
check_repeats(repeats)
check_number_whole(repeats, min = 1)

if (!rlang::is_function(cluster_function)) {
cluster_function <- rlang::arg_match(cluster_function)
Expand Down
30 changes: 14 additions & 16 deletions R/vfold.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ vfold_cv <- function(data, v = 10, repeats = 1,
}

check_strata(strata, data)
check_repeats(repeats)
check_number_whole(repeats, min = 1)

if (repeats == 1) {
split_objs <- vfold_splits(
Expand Down Expand Up @@ -213,7 +213,7 @@ vfold_splits <- function(data, v = 10, strata = NULL, breaks = 4, pool = 0.1, pr
#' @export
group_vfold_cv <- function(data, group = NULL, v = NULL, repeats = 1, balance = c("groups", "observations"), ..., strata = NULL, pool = 0.1) {
check_dots_empty()
check_repeats(repeats)
check_number_whole(repeats, min = 1)
group <- validate_group({{ group }}, data)
balance <- rlang::arg_match(balance)

Expand Down Expand Up @@ -331,23 +331,24 @@ add_vfolds <- function(x, v) {
}

check_v <- function(v, max_v, rows = "rows", prevent_loo = TRUE, call = rlang::caller_env()) {
if (!is.numeric(v) || length(v) != 1 || v < 2) {
cli_abort("{.arg v} must be a single positive integer greater than 1.", call = call)
} else if (v > max_v) {
check_number_whole(v, min = 2, call = call)

if (v > max_v) {
cli_abort(
"The number of {rows} is less than {.arg v} = {.val {v}}.",
call = call
)
} else if (prevent_loo && isTRUE(v == max_v)) {
}
if (prevent_loo && isTRUE(v == max_v)) {
cli_abort(c(
"Leave-one-out cross-validation is not supported by this function.",
"x" = "You set `v` to `nrow(data)`, which would result in a leave-one-out cross-validation.",
"i" = "Use `loo_cv()` in this case."
"x" = "You set {.arg v} to {.code nrow(data)}, which would result in a leave-one-out cross-validation.",
"i" = "Use {.fn loo_cv} in this case."
), call = call)
}
}

check_grouped_strata <- function(group, strata, pool, data) {
check_grouped_strata <- function(group, strata, pool, data, call = caller_env()) {

strata <- tidyselect::vars_select(names(data), !!enquo(strata))

Expand All @@ -363,14 +364,11 @@ check_grouped_strata <- function(group, strata, pool, data) {

if (nrow(vctrs::vec_unique(grouped_table)) !=
nrow(vctrs::vec_unique(grouped_table["group"]))) {
cli_abort("{.arg strata} must be constant across all members of each {.arg group}.")
cli_abort(
"{.field strata} must be constant across all members of each {.field group}.",
call = call
)
}

strata
}

# Validate the `repeats` argument of the resampling functions.
#
# `repeats` must be a single, finite, whole number that is at least 1.
# Anything else aborts with an error attributed to `call` (by default the
# function that called `check_repeats()`), so the user sees the public
# function's name in the error header.
#
# @param repeats Value supplied by the user for the number of repeats.
# @param call Environment used by `cli_abort()` to attribute the error.
# @return `NULL`, invisibly; called for its side effect of erroring.
check_repeats <- function(repeats, call = rlang::caller_env()) {
  # `&&` short-circuits, so the numeric comparisons only run on a length-one
  # numeric. `is.finite()` rejects NA/NaN/Inf, which would otherwise make the
  # condition `NA` (a cryptic base-R error) or let `Inf` through.
  is_valid <- is.numeric(repeats) &&
    length(repeats) == 1 &&
    is.finite(repeats) &&
    repeats >= 1 &&
    repeats == trunc(repeats)

  if (!is_valid) {
    cli_abort("{.arg repeats} must be a single positive integer.", call = call)
  }

  invisible(NULL)
}
8 changes: 4 additions & 4 deletions tests/testthat/_snaps/clustering.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
clustering_cv(iris, Sepal.Length, v = -500)
Condition
Error in `clustering_cv()`:
! `v` must be a single positive integer greater than 1.
! `v` must be a whole number larger than or equal to 2, not the number -500.

---

Expand All @@ -36,23 +36,23 @@
clustering_cv(Orange, v = 1, vars = "Tree")
Condition
Error in `clustering_cv()`:
! `v` must be a single positive integer greater than 1.
! `v` must be a whole number larger than or equal to 2, not the number 1.

---

Code
clustering_cv(Orange, repeats = 0)
Condition
Error in `clustering_cv()`:
! `repeats` must be a single positive integer.
! `repeats` must be a whole number larger than or equal to 1, not the number 0.

---

Code
clustering_cv(Orange, repeats = NULL)
Condition
Error in `clustering_cv()`:
! `repeats` must be a single positive integer.
! `repeats` must be a whole number, not `NULL`.

---

Expand Down
40 changes: 24 additions & 16 deletions tests/testthat/_snaps/vfold.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,29 +41,29 @@
! strata cannot be a <Surv> object.
i Use the time or event variable directly.

# bad args
# v arg is checked

Code
vfold_cv(iris, v = -500)
Condition
Error in `vfold_cv()`:
! `v` must be a single positive integer greater than 1.
! `v` must be a whole number larger than or equal to 2, not the number -500.

---

Code
vfold_cv(iris, v = 1)
Condition
Error in `vfold_cv()`:
! `v` must be a single positive integer greater than 1.
! `v` must be a whole number larger than or equal to 2, not the number 1.

---

Code
vfold_cv(iris, v = NULL)
Condition
Error in `vfold_cv()`:
! `v` must be a single positive integer greater than 1.
! `v` must be a whole number, not `NULL`.

---

Expand All @@ -76,36 +76,36 @@
---

Code
vfold_cv(iris, v = 150, repeats = 2)
vfold_cv(mtcars, v = nrow(mtcars))
Condition
Error in `vfold_cv()`:
! Repeated resampling when `v` is 150 would create identical resamples.
! Leave-one-out cross-validation is not supported by this function.
x You set `v` to `nrow(data)`, which would result in a leave-one-out cross-validation.
i Use `loo_cv()` in this case.

---
# repeats arg is checked

Code
vfold_cv(Orange, repeats = 0)
vfold_cv(iris, v = 150, repeats = 2)
Condition
Error in `vfold_cv()`:
! `repeats` must be a single positive integer.
! Repeated resampling when `v` is 150 would create identical resamples.

---

Code
vfold_cv(Orange, repeats = NULL)
vfold_cv(Orange, repeats = 0)
Condition
Error in `vfold_cv()`:
! `repeats` must be a single positive integer.
! `repeats` must be a whole number larger than or equal to 1, not the number 0.

---

Code
vfold_cv(mtcars, v = nrow(mtcars))
vfold_cv(Orange, repeats = NULL)
Condition
Error in `vfold_cv()`:
! Leave-one-out cross-validation is not supported by this function.
x You set `v` to `nrow(data)`, which would result in a leave-one-out cross-validation.
i Use `loo_cv()` in this case.
! `repeats` must be a whole number, not `NULL`.

# printing

Expand Down Expand Up @@ -191,7 +191,7 @@
group_vfold_cv(Orange, v = 1, group = "Tree")
Condition
Error in `group_vfold_cv()`:
! `v` must be a single positive integer greater than 1.
! `v` must be a whole number larger than or equal to 2, not the number 1.

# grouping -- other balance methods

Expand Down Expand Up @@ -286,6 +286,14 @@
10 <split [96051/3949]> Resample10
# i 20 more rows

# grouping fails for strata not constant across group members

Code
group_vfold_cv(sample_data, group, v = 5, strata = outcome)
Condition
Error in `group_vfold_cv()`:
! strata must be constant across all members of each group.

# grouping -- printing

Code
Expand Down
40 changes: 36 additions & 4 deletions tests/testthat/test-vfold.R
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ test_that("strata arg is checked", {
})
})

test_that("bad args", {
test_that("v arg is checked", {
expect_snapshot(error = TRUE, {
vfold_cv(iris, v = -500)
})
Expand All @@ -117,6 +117,12 @@ test_that("bad args", {
expect_snapshot(error = TRUE, {
vfold_cv(iris, v = 500)
})
expect_snapshot(error = TRUE, {
vfold_cv(mtcars, v = nrow(mtcars))
})
})

test_that("repeats arg is checked", {
expect_snapshot(error = TRUE, {
vfold_cv(iris, v = 150, repeats = 2)
})
Expand All @@ -126,9 +132,6 @@ test_that("bad args", {
expect_snapshot(error = TRUE, {
vfold_cv(Orange, repeats = NULL)
})
expect_snapshot(error = TRUE, {
vfold_cv(mtcars, v = nrow(mtcars))
})
})

test_that("printing", {
Expand Down Expand Up @@ -403,6 +406,35 @@ test_that("grouping -- strata", {
)
})

test_that("grouping fails for strata not constant across group members", {
  set.seed(11)

  # One outcome label per group: 70 groups labelled 0 and 30 labelled 1,
  # shuffled so the labels are not ordered by group id.
  class_labels <- rep(c(0, 1), times = c(70, 30))
  group_table <- tibble(
    group = 1:100,
    outcome = sample(class_labels)
  )

  # 1e5 observations, each assigned at random to one of the 100 groups.
  group_ids <- sample(1:100, 1e5, replace = TRUE)
  observation_table <- tibble(
    group = group_ids,
    observation = 1:1e5
  )

  # Attach the group-level outcome to every observation.
  sample_data <- dplyr::full_join(
    group_table,
    observation_table,
    by = "group",
    multiple = "all"
  )

  # Violate the requirement: flip the outcome of the first row only, which is
  # intended to leave that row's group with non-constant strata values.
  first_outcome <- sample_data$outcome[1]
  sample_data$outcome[1] <- if (first_outcome != 0) 0 else 1

  expect_snapshot(error = TRUE, {
    group_vfold_cv(sample_data, group, v = 5, strata = outcome)
  })
})

test_that("grouping -- repeated", {
set.seed(11)
rs2 <- group_vfold_cv(dat1, c, v = 3, repeats = 4)
Expand Down

0 comments on commit bf70b7b

Please sign in to comment.