Skip to content

Commit

Permalink
exclude non-tunable engine arguments in tunable()
Browse files Browse the repository at this point in the history
closes #1104
  • Loading branch information
simonpcouch committed Apr 5, 2024
1 parent 6e8106e commit 21a56c8
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 35 deletions.
20 changes: 10 additions & 10 deletions R/tunable.R
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ tunable.linear_reg <- function(x, ...) {
} else if (x$engine == "brulee") {
res <- add_engine_parameters(res, brulee_linear_engine_args)
}
res
res[!vapply(res$call_info, is.null, logical(1)), ]
}

#' @export
Expand All @@ -260,7 +260,7 @@ tunable.logistic_reg <- function(x, ...) {
} else if (x$engine == "brulee") {
res <- add_engine_parameters(res, brulee_logistic_engine_args)
}
res
res[!vapply(res$call_info, is.null, logical(1)), ]
}

#' @export
Expand All @@ -272,7 +272,7 @@ tunable.multinomial_reg <- function(x, ...) {
} else if (x$engine == "brulee") {
res <- add_engine_parameters(res, brulee_multinomial_engine_args)
}
res
res[!vapply(res$call_info, is.null, logical(1)), ]
}

#' @export
Expand All @@ -295,7 +295,7 @@ tunable.boost_tree <- function(x, ...) {
res$call_info[res$name == "sample_size"] <-
list(list(pkg = "dials", fun = "sample_prop"))
}
res
res[!vapply(res$call_info, is.null, logical(1)), ]
}

#' @export
Expand All @@ -310,7 +310,7 @@ tunable.rand_forest <- function(x, ...) {
} else if (x$engine == "aorsf") {
res <- add_engine_parameters(res, aorsf_engine_args)
}
res
res[!vapply(res$call_info, is.null, logical(1)), ]
}

#' @export
Expand All @@ -319,7 +319,7 @@ tunable.mars <- function(x, ...) {
if (x$engine == "earth") {
res <- add_engine_parameters(res, earth_engine_args)
}
res
res[!vapply(res$call_info, is.null, logical(1)), ]
}

#' @export
Expand All @@ -333,7 +333,7 @@ tunable.decision_tree <- function(x, ...) {
partykit_engine_args %>%
dplyr::mutate(component = "decision_tree"))
}
res
res[!vapply(res$call_info, is.null, logical(1)), ]
}

#' @export
Expand All @@ -343,7 +343,7 @@ tunable.svm_poly <- function(x, ...) {
res$call_info[res$name == "degree"] <-
list(list(pkg = "dials", fun = "prod_degree", range = c(1L, 3L)))
}
res
res[!vapply(res$call_info, is.null, logical(1)), ]
}


Expand All @@ -357,7 +357,7 @@ tunable.mlp <- function(x, ...) {
res$call_info[res$name == "epochs"] <-
list(list(pkg = "dials", fun = "epochs", range = c(5L, 500L)))
}
res
res[!vapply(res$call_info, is.null, logical(1)), ]
}

#' @export
Expand All @@ -366,7 +366,7 @@ tunable.survival_reg <- function(x, ...) {
if (x$engine == "flexsurvspline") {
res <- add_engine_parameters(res, flexsurvspline_engine_args)
}
res
res[!vapply(res$call_info, is.null, logical(1)), ]
}

# nocov end
Expand Down
41 changes: 16 additions & 25 deletions tests/testthat/_snaps/tunable.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,11 @@
Code
tunable(spec %>% set_engine("glmnet", dfmax = tune()))
Output
# A tibble: 3 x 5
# A tibble: 2 x 5
name call_info source component component_id
<chr> <list> <chr> <chr> <chr>
1 penalty <named list [2]> model_spec linear_reg main
2 mixture <named list [3]> model_spec linear_reg main
3 dfmax <NULL> model_spec linear_reg engine

# tunable.logistic_reg()

Expand Down Expand Up @@ -95,12 +94,11 @@
Code
tunable(spec %>% set_engine("glmnet", dfmax = tune()))
Output
# A tibble: 3 x 5
# A tibble: 2 x 5
name call_info source component component_id
<chr> <list> <chr> <chr> <chr>
1 penalty <named list [2]> model_spec logistic_reg main
2 mixture <named list [3]> model_spec logistic_reg main
3 dfmax <NULL> model_spec logistic_reg engine

# tunable.multinom_reg()

Expand Down Expand Up @@ -244,7 +242,7 @@
Code
tunable(spec %>% set_engine("xgboost", feval = tune()))
Output
# A tibble: 9 x 5
# A tibble: 8 x 5
name call_info source component component_id
<chr> <list> <chr> <chr> <chr>
1 tree_depth <named list [2]> model_spec boost_tree main
Expand All @@ -255,7 +253,6 @@
6 loss_reduction <named list [2]> model_spec boost_tree main
7 sample_size <named list [2]> model_spec boost_tree main
8 stop_iter <named list [2]> model_spec boost_tree main
9 feval <NULL> model_spec boost_tree engine

# tunable.rand_forest()

Expand Down Expand Up @@ -310,13 +307,12 @@
Code
tunable(spec %>% set_engine("ranger", min.bucket = tune()))
Output
# A tibble: 4 x 5
name call_info source component component_id
<chr> <list> <chr> <chr> <chr>
1 mtry <named list [2]> model_spec rand_forest main
2 trees <named list [2]> model_spec rand_forest main
3 min_n <named list [2]> model_spec rand_forest main
4 min.bucket <NULL> model_spec rand_forest engine
# A tibble: 3 x 5
name call_info source component component_id
<chr> <list> <chr> <chr> <chr>
1 mtry <named list [2]> model_spec rand_forest main
2 trees <named list [2]> model_spec rand_forest main
3 min_n <named list [2]> model_spec rand_forest main

# tunable.mars()

Expand Down Expand Up @@ -347,13 +343,12 @@
Code
tunable(spec %>% set_engine("earth", minspan = tune()))
Output
# A tibble: 4 x 5
# A tibble: 3 x 5
name call_info source component component_id
<chr> <list> <chr> <chr> <chr>
1 num_terms <named list [3]> model_spec mars main
2 prod_degree <named list [2]> model_spec mars main
3 prune_method <named list [2]> model_spec mars main
4 minspan <NULL> model_spec mars engine

# tunable.decision_tree()

Expand Down Expand Up @@ -405,13 +400,12 @@
Code
tunable(spec %>% set_engine("rpart", parms = tune()))
Output
# A tibble: 4 x 5
# A tibble: 3 x 5
name call_info source component component_id
<chr> <list> <chr> <chr> <chr>
1 tree_depth <named list [2]> model_spec decision_tree main
2 min_n <named list [2]> model_spec decision_tree main
3 cost_complexity <named list [2]> model_spec decision_tree main
4 parms <NULL> model_spec decision_tree engine

# tunable.svm_poly()

Expand Down Expand Up @@ -444,14 +438,13 @@
Code
tunable(spec %>% set_engine("kernlab", tol = tune()))
Output
# A tibble: 5 x 5
# A tibble: 4 x 5
name call_info source component component_id
<chr> <list> <chr> <chr> <chr>
1 cost <named list [3]> model_spec svm_poly main
2 degree <named list [3]> model_spec svm_poly main
3 scale_factor <named list [2]> model_spec svm_poly main
4 margin <named list [2]> model_spec svm_poly main
5 tol <NULL> model_spec svm_poly engine

# tunable.mlp()

Expand Down Expand Up @@ -511,15 +504,14 @@
Code
tunable(spec %>% set_engine("keras", ragged = tune()))
Output
# A tibble: 6 x 5
# A tibble: 5 x 5
name call_info source component component_id
<chr> <list> <chr> <chr> <chr>
1 hidden_units <named list [2]> model_spec mlp main
2 penalty <named list [2]> model_spec mlp main
3 dropout <named list [2]> model_spec mlp main
4 epochs <named list [2]> model_spec mlp main
5 activation <named list [2]> model_spec mlp main
6 ragged <NULL> model_spec mlp engine

# tunable.survival_reg()

Expand All @@ -544,8 +536,7 @@
Code
tunable(spec %>% set_engine("survival", parms = tune()))
Output
# A tibble: 1 x 5
name call_info source component component_id
<chr> <list> <chr> <chr> <chr>
1 parms <NULL> model_spec survival_reg engine
# A tibble: 0 x 5
# i 5 variables: name <chr>, call_info <list>, source <chr>, component <chr>,
# component_id <chr>

0 comments on commit 21a56c8

Please sign in to comment.