From 21a56c85da2fff272d3f594eaa6e39acabe6e99f Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Fri, 5 Apr 2024 10:48:16 -0500 Subject: [PATCH] exclude non-tunable engine arguments in `tunable()` closes #1104 --- R/tunable.R | 20 ++++++++-------- tests/testthat/_snaps/tunable.md | 41 +++++++++++++------------------- 2 files changed, 26 insertions(+), 35 deletions(-) diff --git a/R/tunable.R b/R/tunable.R index 85c8bff29..271c470e2 100644 --- a/R/tunable.R +++ b/R/tunable.R @@ -248,7 +248,7 @@ tunable.linear_reg <- function(x, ...) { } else if (x$engine == "brulee") { res <- add_engine_parameters(res, brulee_linear_engine_args) } - res + res[!vapply(res$call_info, is.null, logical(1)), ] } #' @export @@ -260,7 +260,7 @@ tunable.logistic_reg <- function(x, ...) { } else if (x$engine == "brulee") { res <- add_engine_parameters(res, brulee_logistic_engine_args) } - res + res[!vapply(res$call_info, is.null, logical(1)), ] } #' @export @@ -272,7 +272,7 @@ tunable.multinomial_reg <- function(x, ...) { } else if (x$engine == "brulee") { res <- add_engine_parameters(res, brulee_multinomial_engine_args) } - res + res[!vapply(res$call_info, is.null, logical(1)), ] } #' @export @@ -295,7 +295,7 @@ tunable.boost_tree <- function(x, ...) { res$call_info[res$name == "sample_size"] <- list(list(pkg = "dials", fun = "sample_prop")) } - res + res[!vapply(res$call_info, is.null, logical(1)), ] } #' @export @@ -310,7 +310,7 @@ tunable.rand_forest <- function(x, ...) { } else if (x$engine == "aorsf") { res <- add_engine_parameters(res, aorsf_engine_args) } - res + res[!vapply(res$call_info, is.null, logical(1)), ] } #' @export @@ -319,7 +319,7 @@ tunable.mars <- function(x, ...) { if (x$engine == "earth") { res <- add_engine_parameters(res, earth_engine_args) } - res + res[!vapply(res$call_info, is.null, logical(1)), ] } #' @export @@ -333,7 +333,7 @@ tunable.decision_tree <- function(x, ...) { partykit_engine_args %>% dplyr::mutate(component = "decision_tree")) } - res + res[!vapply(res$call_info, is.null, logical(1)), ] } #' @export @@ -343,7 +343,7 @@ tunable.svm_poly <- function(x, ...) { res$call_info[res$name == "degree"] <- list(list(pkg = "dials", fun = "prod_degree", range = c(1L, 3L))) } - res + res[!vapply(res$call_info, is.null, logical(1)), ] } @@ -357,7 +357,7 @@ tunable.mlp <- function(x, ...) { res$call_info[res$name == "epochs"] <- list(list(pkg = "dials", fun = "epochs", range = c(5L, 500L))) } - res + res[!vapply(res$call_info, is.null, logical(1)), ] } #' @export @@ -366,7 +366,7 @@ tunable.survival_reg <- function(x, ...) { if (x$engine == "flexsurvspline") { res <- add_engine_parameters(res, flexsurvspline_engine_args) } - res + res[!vapply(res$call_info, is.null, logical(1)), ] } # nocov end diff --git a/tests/testthat/_snaps/tunable.md b/tests/testthat/_snaps/tunable.md index e2ec2c208..9f8b6ba3b 100644 --- a/tests/testthat/_snaps/tunable.md +++ b/tests/testthat/_snaps/tunable.md @@ -43,12 +43,11 @@ Code tunable(spec %>% set_engine("glmnet", dfmax = tune())) Output - # A tibble: 3 x 5 + # A tibble: 2 x 5 name call_info source component component_id 1 penalty model_spec linear_reg main 2 mixture model_spec linear_reg main - 3 dfmax model_spec linear_reg engine # tunable.logistic_reg() @@ -95,12 +94,11 @@ Code tunable(spec %>% set_engine("glmnet", dfmax = tune())) Output - # A tibble: 3 x 5 + # A tibble: 2 x 5 name call_info source component component_id 1 penalty model_spec logistic_reg main 2 mixture model_spec logistic_reg main - 3 dfmax model_spec logistic_reg engine # tunable.multinom_reg() @@ -244,7 +242,7 @@ Code tunable(spec %>% set_engine("xgboost", feval = tune())) Output - # A tibble: 9 x 5 + # A tibble: 8 x 5 name call_info source component component_id 1 tree_depth model_spec boost_tree main @@ -255,7 +253,6 @@ 6 loss_reduction model_spec boost_tree main 7 sample_size model_spec boost_tree main 8 stop_iter model_spec boost_tree main - 9 feval model_spec boost_tree engine # tunable.rand_forest() @@ -310,13 +307,12 @@ Code tunable(spec %>% set_engine("ranger", min.bucket = tune())) Output - # A tibble: 4 x 5 - name call_info source component component_id - - 1 mtry model_spec rand_forest main - 2 trees model_spec rand_forest main - 3 min_n model_spec rand_forest main - 4 min.bucket model_spec rand_forest engine + # A tibble: 3 x 5 + name call_info source component component_id + + 1 mtry model_spec rand_forest main + 2 trees model_spec rand_forest main + 3 min_n model_spec rand_forest main # tunable.mars() @@ -347,13 +343,12 @@ Code tunable(spec %>% set_engine("earth", minspan = tune())) Output - # A tibble: 4 x 5 + # A tibble: 3 x 5 name call_info source component component_id 1 num_terms model_spec mars main 2 prod_degree model_spec mars main 3 prune_method model_spec mars main - 4 minspan model_spec mars engine # tunable.decision_tree() @@ -405,13 +400,12 @@ Code tunable(spec %>% set_engine("rpart", parms = tune())) Output - # A tibble: 4 x 5 + # A tibble: 3 x 5 name call_info source component component_id 1 tree_depth model_spec decision_tree main 2 min_n model_spec decision_tree main 3 cost_complexity model_spec decision_tree main - 4 parms model_spec decision_tree engine # tunable.svm_poly() @@ -444,14 +438,13 @@ Code tunable(spec %>% set_engine("kernlab", tol = tune())) Output - # A tibble: 5 x 5 + # A tibble: 4 x 5 name call_info source component component_id 1 cost model_spec svm_poly main 2 degree model_spec svm_poly main 3 scale_factor model_spec svm_poly main 4 margin model_spec svm_poly main - 5 tol model_spec svm_poly engine # tunable.mlp() @@ -511,7 +504,7 @@ Code tunable(spec %>% set_engine("keras", ragged = tune())) Output - # A tibble: 6 x 5 + # A tibble: 5 x 5 name call_info source component component_id 1 hidden_units model_spec mlp main @@ -519,7 +512,6 @@ 3 dropout model_spec mlp main 4 epochs model_spec mlp main 5 activation model_spec mlp main - 6 ragged model_spec mlp engine # tunable.survival_reg() @@ -544,8 +536,7 @@ Code tunable(spec %>% set_engine("survival", parms = tune())) Output - # A tibble: 1 x 5 - name call_info source component component_id - - 1 parms model_spec survival_reg engine + # A tibble: 0 x 5 + # i 5 variables: name , call_info , source , component , + # component_id