diff --git a/tests/testthat/helper-cache.R b/tests/testthat/helper-cache.R new file mode 100644 index 00000000..086889d0 --- /dev/null +++ b/tests/testthat/helper-cache.R @@ -0,0 +1,65 @@ +# tools for caching results within testthat. + +is_object_available <- function(object, fail = FALSE, save_path = "saved_objects") { + cl <- match.call() + file_name <- paste0(cl$object, ".RData") + file_path <- file.path(save_path, file_name) + has_file <- file.exists(file_path) + if (fail && !has_file) { + msg <- paste0("File '", file_name, "' is not in '", save_path, "'.") + cli::cli_abort(msg) + } + has_file +} + +save_object <- function(object, save_path = "saved_objects") { + cl <- match.call() + file_name <- paste0(cl$object, ".RData") + file_path <- file.path(save_path, file_name) + res <- try(save(object, file = file_path), silent = TRUE) + # returned NULL if it worked + if (is.null(res)) { + # verify + res <- file.exists(file_path) + } else { + # save failed + print(as.character(res)) + res <- FALSE + } + res +} + +return_object <- function(object, save_path = "saved_objects") { + cl <- match.call() + file_name <- paste0(cl$object, ".RData") + file_path <- file.path(save_path, file_name) + load(file_path) + object +} + +purge_objects <- function(save_path = "saved_objects") { + all_files <- list.files(save_path, pattern = "RData$", full.names = TRUE) + res <- vapply(all_files, unlink, integer(1)) + df_res <- tibble::tibble(file = names(res)) + df_res$deleted <- ifelse(res == 0, TRUE, FALSE) + invisible(df_res) +} + +# Example usage +if (FALSE) { + pkg <- "tune" + is_object_available(pkg) + + save_object(pkg) + is_object_available(pkg) + + rm(pkg) + pkg <- return_object(pkg) + pkg + + file_86 <- purge_objects() + file_86 + is_object_available(pkg) + + is_object_available(some_other_pkg, fail = TRUE) +} diff --git a/tests/testthat/saved_objects/grid_static_res.RData b/tests/testthat/saved_objects/grid_static_res.RData new file mode 100644 index 00000000..00db0adc Binary files /dev/null and b/tests/testthat/saved_objects/grid_static_res.RData differ diff --git a/tests/testthat/test-survival-tune-grid.R b/tests/testthat/test-survival-tune-grid.R index 6fbddace..1d98fd29 100644 --- a/tests/testthat/test-survival-tune-grid.R +++ b/tests/testthat/test-survival-tune-grid.R @@ -9,52 +9,70 @@ skip_if_not_installed("censored", minimum_version = "0.2.0.9000") skip_if_not_installed("tune", minimum_version = "1.1.1.9001") skip_if_not_installed("yardstick", minimum_version = "1.2.0.9001") -test_that("grid tuning survival models with static metric", { +test_that("grid tuning with static metric", { skip_if_not_installed("prodlim") skip_if_not_installed("coin") # required for partykit engine - stc_mtrc <- metric_set(concordance_survival) - - # standard setup start - set.seed(1) - sim_dat <- prodlim::SimSurv(500) %>% - mutate(event_time = Surv(time, event)) %>% - select(event_time, X1, X2) - - set.seed(2) - split <- initial_split(sim_dat) - sim_tr <- training(split) - sim_te <- testing(split) - sim_rs <- vfold_cv(sim_tr) - - time_points <- c(10, 1, 5, 15) - - mod_spec <- - decision_tree(tree_depth = tune(), min_n = 4) %>% - set_engine("partykit") %>% - set_mode("censored regression") + if (is_object_available(grid_static_res)) { + grid_static_res <- return_object(grid_static_res) + } else { + stc_mtrc <- metric_set(concordance_survival) + + set.seed(1) + sim_dat <- prodlim::SimSurv(500) %>% + mutate(event_time = Surv(time, event)) %>% + select(event_time, X1, X2) + + set.seed(2) + split <- initial_split(sim_dat) + sim_tr <- training(split) + sim_te <- testing(split) + sim_rs <- vfold_cv(sim_tr) + + time_points <- c(10, 1, 5, 15) + + mod_spec <- + decision_tree(tree_depth = tune(), min_n = 4) %>% + set_engine("partykit") %>% + set_mode("censored regression") + + grid <- tibble(tree_depth = c(1, 2, 10)) + + gctrl <- control_grid(save_pred = TRUE) + + set.seed(2193) + grid_static_res <- + mod_spec %>% + tune_grid( + event_time ~ X1 + X2, + resamples = sim_rs, + grid = grid, + metrics = stc_mtrc, + control = gctrl + ) + save_object(grid_static_res) + } + + expect_s3_class(grid_static_res, "tune_results") +}) - grid <- tibble(tree_depth = c(1, 2, 10)) - gctrl <- control_grid(save_pred = TRUE) - # standard setup end +test_that("grid tuning with static metric - check structure", { - set.seed(2193) - grid_static_res <- - mod_spec %>% - tune_grid( - event_time ~ X1 + X2, - resamples = sim_rs, - grid = grid, - metrics = stc_mtrc, - control = gctrl - ) + is_object_available(grid_static_res, fail = TRUE) + grid_static_res <- return_object(grid_static_res) expect_false(".eval_time" %in% names(grid_static_res$.metrics[[1]])) expect_equal( names(grid_static_res$.predictions[[1]]), c(".pred_time", ".row", "tree_depth", "event_time", ".config") ) +}) + +test_that("grid tuning with static metric - autoplot", { + + is_object_available(grid_static_res, fail = TRUE) + grid_static_res <- return_object(grid_static_res) expect_snapshot_plot( print(autoplot(grid_static_res)), @@ -62,6 +80,7 @@ test_that("grid tuning survival models with static metric", { ) }) + test_that("grid tuning survival models with integrated metric", { skip_if_not_installed("prodlim") skip_if_not_installed("coin") # required for partykit engine