Skip to content

Commit

Permalink
Merge pull request #499 from tidymodels/cli-initial_split
Browse files Browse the repository at this point in the history
Use cli errors in `initial_split.R` and `complement.R`
  • Loading branch information
hfrick authored Jul 19, 2024
2 parents e2c724f + 81b34fc commit 041c555
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 27 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ export(validation_split)
export(validation_time_split)
export(vfold_cv)
import(vctrs)
importFrom(cli,cli_abort)
importFrom(dplyr,"%>%")
importFrom(dplyr,arrange)
importFrom(dplyr,arrange_)
Expand Down
6 changes: 3 additions & 3 deletions R/complement.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ complement.apparent_split <- function(x, ...) {

#' @export
complement.default <- function(x, ...) {
cls <- paste0("'", class(x), "'", collapse = ", ")
rlang::abort(
paste("No `complement()` method for this class(es)", cls)
x_cls <- class(x)
cli_abort(
"No {.fn complement} method for objects of class{?es}: {.cls {x_cls}}"
)
}

Expand Down
6 changes: 3 additions & 3 deletions R/initial_split.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,17 @@ initial_split <- function(data, prop = 3 / 4,
initial_time_split <- function(data, prop = 3 / 4, lag = 0, ...) {
check_dots_empty()
if (!is.numeric(prop) | prop >= 1 | prop <= 0) {
rlang::abort("`prop` must be a number on (0, 1).")
cli_abort("{.arg prop} must be a number on (0, 1).")
}

if (!is.numeric(lag) | !(lag %% 1 == 0)) {
rlang::abort("`lag` must be a whole number.")
cli_abort("{.arg lag} must be a whole number.")
}

n_train <- floor(nrow(data) * prop)

if (lag > n_train) {
rlang::abort("`lag` must be less than or equal to the number of training observations.")
cli_abort("{.arg lag} must be less than or equal to the number of training observations.")
}

split <- rsplit(data, 1:n_train, (n_train + 1 - lag):nrow(data))
Expand Down
1 change: 1 addition & 0 deletions R/rsample-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

## usethis namespace: start
#' @importFrom lifecycle deprecated
#' @importFrom cli cli_abort
## usethis namespace: end
NULL

Expand Down
12 changes: 10 additions & 2 deletions tests/testthat/_snaps/initial.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
# default time param with lag
# `initial_time_split()` error messages

Code
initial_time_split(drinks, prop = 2)
Condition
Error in `initial_time_split()`:
! `prop` must be a number on (0, 1).

---

Code
initial_time_split(drinks, lag = 12.5)
Expand All @@ -9,7 +17,7 @@
---

Code
initial_time_split(drinks, lag = 500)
initial_time_split(drinks, lag = nrow(drinks) + 1)
Condition
Error in `initial_time_split()`:
! `lag` must be less than or equal to the number of training observations.
Expand Down
22 changes: 19 additions & 3 deletions tests/testthat/_snaps/rsplit.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,27 @@
<list> <chr>
1 <split [24/8]> validation

# default complement method errors
# `complement()` error messages

Code
complement("a string")
complement(fake_rsplit)
Condition
Error in `complement()`:
! No `complement()` method for this class(es) 'character'
! No `complement()` method for objects of class: <not_an_rsplit>

---

Code
complement(fake_rsplit)
Condition
Error in `complement()`:
! No `complement()` method for objects of classes: <not_an_rsplit/really_not_an_rsplit>

---

Code
get_stored_out_id(list(out_id = NA))
Condition
Error in `get_stored_out_id()`:
! Cannot derive the assessment set for this type of resampling.

25 changes: 14 additions & 11 deletions tests/testthat/test-initial.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,23 @@ test_that("default time param with lag", {
expect_equal(nrow(ts1), ceiling(nrow(dat1) / 4) + 5)
expect_equal(tr1, dplyr::slice(dat1, 1:floor(nrow(dat1) * 3 / 4)))
expect_equal(ts1, dat1[(floor(nrow(dat1) * 3 / 4) + 1 - 5):nrow(dat1), ], ignore_attr = "row.names")
})

test_that("`initial_time_split()` error messages", {
skip_if_not_installed("modeldata")
data(drinks, package = "modeldata")
data(drinks, package = "modeldata", envir = rlang::current_env())

# Whole numbers only
expect_snapshot(
initial_time_split(drinks, lag = 12.5),
error = TRUE
)
# Lag must be less than number of training observations
expect_snapshot(
initial_time_split(drinks, lag = 500),
error = TRUE
)
expect_snapshot(error = TRUE, {
initial_time_split(drinks, prop = 2)
})

expect_snapshot(error = TRUE, {
initial_time_split(drinks, lag = 12.5)
})

expect_snapshot(error = TRUE, {
initial_time_split(drinks, lag = nrow(drinks) + 1)
})
})

test_that("default group param", {
Expand Down
18 changes: 13 additions & 5 deletions tests/testthat/test-rsplit.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,19 @@ test_that("print methods", {
})
})

test_that("default complement method errors", {
expect_snapshot(
complement("a string"),
error = TRUE
)
test_that("`complement()` error messages", {
fake_rsplit <- 1
class(fake_rsplit) <- c("not_an_rsplit")
expect_snapshot(error = TRUE, {
complement(fake_rsplit)
})
class(fake_rsplit) <- c("not_an_rsplit", "really_not_an_rsplit")
expect_snapshot(error = TRUE, {
complement(fake_rsplit)
})
expect_snapshot(error = TRUE, {
get_stored_out_id(list(out_id = NA))
})
})

test_that("as.data.frame() works for permutations with Surv object without the survival package loaded - issue #443", {
Expand Down

0 comments on commit 041c555

Please sign in to comment.