Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache helpers #133

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions tests/testthat/helper-cache.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# tools for caching results within testthat.

is_object_available <- function(object, fail = FALSE, save_path = "saved_objects") {
cl <- match.call()
file_name <- paste0(cl$object, ".RData")
file_path <- file.path(save_path, file_name)
has_file <- file.exists(file_path)
if (fail && !has_file) {
msg <- paste0("File '", file_name, "' is not in '", save_path, "'.")
cli::cli_abort(msg)
}
has_file
}

save_object <- function(object, save_path = "saved_objects") {
cl <- match.call()
file_name <- paste0(cl$object, ".RData")
file_path <- file.path(save_path, file_name)
res <- try(save(object, file = file_path), silent = TRUE)
# returned NULL if it worked
if (is.null(res)) {
# verify
res <- file.exists(file_path)
} else {
# save failed
print(as.character(res))
res <- FALSE
}
res
}

return_object <- function(object, save_path = "saved_objects") {
cl <- match.call()
file_name <- paste0(cl$object, ".RData")
file_path <- file.path(save_path, file_name)
load(file_path)
object
}

purge_objects <- function(save_path = "saved_objects") {
all_files <- list.files(save_path, pattern = "RData$", full.names = TRUE)
res <- vapply(all_files, unlink, integer(1))
df_res <- tibble::tibble(file = names(res))
df_res$deleted <- ifelse(res == 0, TRUE, FALSE)
invisible(df_res)
}

# Example usage
if (FALSE) {
pkg <- "tune"
is_object_available(pkg)

save_object(pkg)
is_object_available(pkg)

rm(pkg)
pkg <- return_object(pkg)
pkg

file_86 <- purge_objects()
file_86
is_object_available(pkg)

is_object_available(some_other_pkg, fail = TRUE)
}
Binary file not shown.
87 changes: 53 additions & 34 deletions tests/testthat/test-survival-tune-grid.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,59 +9,78 @@ skip_if_not_installed("censored", minimum_version = "0.2.0.9000")
skip_if_not_installed("tune", minimum_version = "1.1.1.9001")
skip_if_not_installed("yardstick", minimum_version = "1.2.0.9001")

test_that("grid tuning survival models with static metric", {
test_that("grid tuning with static metric", {
skip_if_not_installed("prodlim")
skip_if_not_installed("coin") # required for partykit engine

stc_mtrc <- metric_set(concordance_survival)

# standard setup start
set.seed(1)
sim_dat <- prodlim::SimSurv(500) %>%
mutate(event_time = Surv(time, event)) %>%
select(event_time, X1, X2)

set.seed(2)
split <- initial_split(sim_dat)
sim_tr <- training(split)
sim_te <- testing(split)
sim_rs <- vfold_cv(sim_tr)

time_points <- c(10, 1, 5, 15)

mod_spec <-
decision_tree(tree_depth = tune(), min_n = 4) %>%
set_engine("partykit") %>%
set_mode("censored regression")
if (is_object_available(grid_static_res)) {
grid_static_res <- return_object(grid_static_res)
} else {
stc_mtrc <- metric_set(concordance_survival)

set.seed(1)
sim_dat <- prodlim::SimSurv(500) %>%
mutate(event_time = Surv(time, event)) %>%
select(event_time, X1, X2)

set.seed(2)
split <- initial_split(sim_dat)
sim_tr <- training(split)
sim_te <- testing(split)
sim_rs <- vfold_cv(sim_tr)

time_points <- c(10, 1, 5, 15)

mod_spec <-
decision_tree(tree_depth = tune(), min_n = 4) %>%
set_engine("partykit") %>%
set_mode("censored regression")

grid <- tibble(tree_depth = c(1, 2, 10))

gctrl <- control_grid(save_pred = TRUE)

set.seed(2193)
grid_static_res <-
mod_spec %>%
tune_grid(
event_time ~ X1 + X2,
resamples = sim_rs,
grid = grid,
metrics = stc_mtrc,
control = gctrl
)
save_object(grid_static_res)
}

expect_s3_class(grid_static_res, "tune_results")
})

grid <- tibble(tree_depth = c(1, 2, 10))

gctrl <- control_grid(save_pred = TRUE)
# standard setup end
test_that("grid tuning with static metric - check structure", {

set.seed(2193)
grid_static_res <-
mod_spec %>%
tune_grid(
event_time ~ X1 + X2,
resamples = sim_rs,
grid = grid,
metrics = stc_mtrc,
control = gctrl
)
is_object_available(grid_static_res, fail = TRUE)
grid_static_res <- return_object(grid_static_res)

expect_false(".eval_time" %in% names(grid_static_res$.metrics[[1]]))
expect_equal(
names(grid_static_res$.predictions[[1]]),
c(".pred_time", ".row", "tree_depth", "event_time", ".config")
)
})

test_that("grid tuning with static metric - autoplot", {

is_object_available(grid_static_res, fail = TRUE)
grid_static_res <- return_object(grid_static_res)

expect_snapshot_plot(
print(autoplot(grid_static_res)),
"static-metric-grid-search"
)
})


test_that("grid tuning survival models with integrated metric", {
skip_if_not_installed("prodlim")
skip_if_not_installed("coin") # required for partykit engine
Expand Down
Loading