Skip to content

Commit

Permalink
feat: features can be always included with the always_include colum…
Browse files Browse the repository at this point in the history
…n role (#89)
  • Loading branch information
be-marc authored Nov 17, 2023
1 parent 703fdba commit 34060fb
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 1 deletion.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# mlr3fselect (development version)

* feat: Features can be always included with the `always_include` column role.
* fix: Add `$phash()` method to `AutoFSelector`.
* fix: Include `FSelector` in hash of `AutoFSelector`.
* refactor: Change default batch size of `FSelectorRandomSearch` to 10.
Expand Down
4 changes: 3 additions & 1 deletion R/ObjectiveFSelect.R
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ ObjectiveFSelect = R6Class("ObjectiveFSelect",
tasks = map(private$.xss, function(x) {
state = self$task$feature_names[unlist(x)]
task = self$task$clone()
task$select(state)
always_included = task$col_roles$always_included
task$set_col_roles(always_included, "feature")
task$select(c(state, always_included))
task
})

Expand Down
4 changes: 4 additions & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
x = utils::getFromNamespace("bbotk_reflections", ns = "bbotk")
x$optimizer_properties = c(x$optimizer_properties, "requires_model")

x = utils::getFromNamespace("mlr_reflections", ns = "mlr3")
x$task_col_roles$classif = c(x$task_col_roles$classif, "always_included")
x$task_col_roles$regr = c(x$task_col_roles$regr, "always_included")

# callbacks
x = utils::getFromNamespace("mlr_callbacks", ns = "mlr3misc")
x$add("mlr3fselect.backup", load_callback_backup)
Expand Down
52 changes: 52 additions & 0 deletions tests/testthat/test_FSelectInstanceSingleCrit.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,55 @@ test_that("result$features works", {
inst$assign_result(xdt, y)
expect_character(inst$result_feature_set)
})

test_that("always include variable works", {
task = tsk("pima")
task$set_col_roles("glucose", "always_included")

learner = lrn("classif.rpart")
resampling = rsmp("cv", folds = 3)

instance = fselect(
fselector = fs("random_search", batch_size = 100),
task = task,
learner = learner,
resampling = resampling,
measure = msr("classif.ce"),
terminator = trm("evals", n_evals = 100),
store_models = TRUE
)

data = as.data.table(instance$archive)

expect_names(instance$archive$cols_x, disjunct.from = "gloucose")
expect_names(names(instance$archive$data), disjunct.from = "gloucose")
walk(data$resample_result, function(rr) {
expect_names(names(rr$learners[[1]]$state$data_prototype), must.include = "glucose")
})
})

test_that("always include variables works", {
task = tsk("pima")
task$set_col_roles(c("glucose", "age"), "always_included")

learner = lrn("classif.rpart")
resampling = rsmp("cv", folds = 3)

instance = fselect(
fselector = fs("random_search", batch_size = 100),
task = task,
learner = learner,
resampling = resampling,
measure = msr("classif.ce"),
terminator = trm("evals", n_evals = 100),
store_models = TRUE
)

data = as.data.table(instance$archive)

expect_names(instance$archive$cols_x, disjunct.from = c("glucose", "age"))
expect_names(names(instance$archive$data), disjunct.from = c("glucose", "age"))
walk(data$resample_result, function(rr) {
expect_names(names(rr$learners[[1]]$state$data_prototype), must.include = c("glucose", "age"))
})
})

0 comments on commit 34060fb

Please sign in to comment.