From 09540cbb199b60e11a59ff9746d25872e364d6c9 Mon Sep 17 00:00:00 2001 From: Marc Becker <33069354+be-marc@users.noreply.github.com> Date: Fri, 15 Dec 2023 21:55:46 +0100 Subject: [PATCH] feat: add importance to result of rfe (#93) * feat: add importance to result of rfe * chore: update news * test: bracket * fix: expect numeric --- NEWS.md | 1 + R/FSelectInstanceSingleCrit.R | 3 ++- R/FSelectorRFE.R | 13 +++++++++++++ inst/testthat/helper_fselector.R | 2 +- tests/testthat/test_FSelectorRFE.R | 2 ++ 5 files changed, 19 insertions(+), 2 deletions(-) diff --git a/NEWS.md b/NEWS.md index f2e4b0a2..f8882873 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,6 @@ # mlr3fselect (development version) +* feat: Add importance scores to result of `FSelectorRFE`. * feat: Add number of features to `as.data.table.ArchiveFSelect()`. * feat: Features can be always included with the `always_include` column role. * fix: Add `$phash()` method to `AutoFSelector`. diff --git a/R/FSelectInstanceSingleCrit.R b/R/FSelectInstanceSingleCrit.R index 90261e53..9bb984f1 100644 --- a/R/FSelectInstanceSingleCrit.R +++ b/R/FSelectInstanceSingleCrit.R @@ -127,7 +127,8 @@ FSelectInstanceSingleCrit = R6Class("FSelectInstanceSingleCrit", #' Optimal outcome. assign_result = function(xdt, y) { # Add feature names to result for easy task subsetting - features = list(self$objective$task$feature_names[as.logical(xdt)]) + feature_names = self$objective$task$feature_names + features = list(feature_names[as.logical(xdt[, feature_names, with = FALSE])]) xdt[, features := list(features)] assert_data_table(xdt, nrows = 1L) assert_names(names(xdt), must.include = self$search_space$ids()) diff --git a/R/FSelectorRFE.R b/R/FSelectorRFE.R index a496ac4a..138e1b33 100644 --- a/R/FSelectorRFE.R +++ b/R/FSelectorRFE.R @@ -150,6 +150,19 @@ FSelectorRFE = R6Class("FSelectorRFE", subsets = rfe_subsets(n, n_features, feature_number, subset_sizes, feature_fraction) rfe_workhorse(inst, subsets, recursive, aggregation) + }, + + .assign_result = function(inst) { + assert_class(inst, "FSelectInstanceSingleCrit") + res = inst$archive$best() + + xdt = res[, c(inst$search_space$ids(), "importance"), with = FALSE] + + # unlist keeps name! + y = unlist(res[, inst$archive$cols_y, with = FALSE]) + inst$assign_result(xdt, y) + + invisible(NULL) } ) ) diff --git a/inst/testthat/helper_fselector.R b/inst/testthat/helper_fselector.R index 77656c69..fd832bc0 100644 --- a/inst/testthat/helper_fselector.R +++ b/inst/testthat/helper_fselector.R @@ -16,7 +16,7 @@ test_fselector = function(.key, ..., term_evals = NULL, store_models = FALSE) { # result checks archive = inst$archive expect_data_table(inst$result, nrows = 1) - expect_names(names(inst$result), identical.to = c("x1", "x2", "x3", "x4", "features", "dummy")) + expect_names(names(inst$result), must.include = c("x1", "x2", "x3", "x4", "features", "dummy")) expect_subset(inst$result$features[[1]], c("x1", "x2", "x3", "x4")) expect_data_table(inst$result_x_search_space, nrows = 1, ncols = 4, types = "logical") expect_names(names(inst$result_x_search_space), identical.to = c("x1", "x2", "x3", "x4")) diff --git a/tests/testthat/test_FSelectorRFE.R b/tests/testthat/test_FSelectorRFE.R index b34966b5..8f40d6a0 100644 --- a/tests/testthat/test_FSelectorRFE.R +++ b/tests/testthat/test_FSelectorRFE.R @@ -1,6 +1,8 @@ test_that("importance is stored in the archive", { z = test_fselector("rfe", store_models = TRUE) a = z$inst$archive$data + expect_names(names(z$inst$result), must.include = "importance") + expect_numeric(z$inst$result$importance[[1]]) expect_names(names(z$inst$archive$data), must.include = "importance") pwalk(a, function(x1, x2, x3, x4, importance, ...) expect_equal(x1 + x2 + x3 + x4, length(importance))) })