Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support (abstract) importance in multiple learners #222

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ importFrom(R6,R6Class)
importFrom(mlr3,LearnerClassif)
importFrom(mlr3,LearnerRegr)
importFrom(mlr3,mlr_learners)
importFrom(stats,coef)
importFrom(stats,predict)
importFrom(stats,reformulate)
importFrom(utils,bibentry)
importFrom(utils,tail)
14 changes: 13 additions & 1 deletion R/LearnerClassifCVGlmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ LearnerClassifCVGlmnet = R6Class("LearnerClassifCVGlmnet",
param_set = ps,
predict_types = c("response", "prob"),
feature_types = c("logical", "integer", "numeric"),
properties = c("weights", "twoclass", "multiclass", "selected_features"),
properties = c("importance", "selected_features", "weights", "twoclass", "multiclass"),
packages = c("mlr3learners", "glmnet"),
man = "mlr3learners::mlr_learners_classif.cv_glmnet"
)
Expand All @@ -95,6 +95,18 @@ LearnerClassifCVGlmnet = R6Class("LearnerClassifCVGlmnet",
#' @return (`character()`) of feature names.
selected_features = function(lambda = NULL) {
glmnet_selected_features(self, lambda)
},

#' @description
#' Returns importance scores, calculated from the path of lambda values.
#' First, the largest `lambda` at which the feature was first included in the model
#' with a nonzero coefficient is determined.
#' Second, the [rank()] of these lambda values is calculated (using averaging for ties)
#' and returned as importance scores.
#'
#' @return (named `numeric()`) of importance scores.
importance = function() {
glmnet_importance(self)
}
),

Expand Down
14 changes: 13 additions & 1 deletion R/LearnerClassifGlmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ LearnerClassifGlmnet = R6Class("LearnerClassifGlmnet",
param_set = ps,
predict_types = c("response", "prob"),
feature_types = c("logical", "integer", "numeric"),
properties = c("weights", "twoclass", "multiclass"),
properties = c("selected_features", "importance", "weights", "twoclass", "multiclass"),
packages = c("mlr3learners", "glmnet"),
man = "mlr3learners::mlr_learners_classif.glmnet"
)
Expand All @@ -105,6 +105,18 @@ LearnerClassifGlmnet = R6Class("LearnerClassifGlmnet",
#' @return (`character()`) of feature names.
selected_features = function(lambda = NULL) {
glmnet_selected_features(self, lambda)
},

#' @description
#' Returns importance scores, calculated from the path of lambda values.
#' First, the largest `lambda` at which the feature was first included in the model
#' with a nonzero coefficient is determined.
#' Second, the [rank()] of these lambda values is calculated (using averaging for ties)
#' and returned as importance scores.
#'
#' @return (named `numeric()`) of importance scores.
importance = function() {
glmnet_importance(self)
}
),

Expand Down
11 changes: 10 additions & 1 deletion R/LearnerClassifLogReg.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ LearnerClassifLogReg = R6Class("LearnerClassifLogReg",
param_set = ps,
predict_types = c("response", "prob"),
feature_types = c("logical", "integer", "numeric", "character", "factor", "ordered"),
properties = c("weights", "twoclass", "loglik"),
properties = c("weights", "twoclass", "loglik", "importance"),
packages = c("mlr3learners", "stats"),
man = "mlr3learners::mlr_learners_classif.log_reg"
)
Expand All @@ -63,6 +63,15 @@ LearnerClassifLogReg = R6Class("LearnerClassifLogReg",
#' Extract the log-likelihood (e.g., via [stats::logLik()]) from the fitted model.
loglik = function() {
extract_loglik(self)
},

#' @description
#' Importance scores as \eqn{-log_{10}()}{-log10()} transformed \eqn{p}-values,
#' extracted from [summary()].
#' Does not work if the model has been fitted on factor features with more than 2 levels.
#' @return Named `numeric()`.
importance = function() {
lin_model_importance(self)
}
),

Expand Down
1 change: 1 addition & 0 deletions R/LearnerClassifMultinom.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ LearnerClassifMultinom = R6Class("LearnerClassifMultinom",
if ("weights" %in% task$properties) {
pv$weights = task$weights$weight
}

if (!is.null(pv$summ)) {
pv$summ = as.integer(pv$summ)
}
Expand Down
14 changes: 13 additions & 1 deletion R/LearnerRegrCVGlmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ LearnerRegrCVGlmnet = R6Class("LearnerRegrCVGlmnet",
id = "regr.cv_glmnet",
param_set = ps,
feature_types = c("logical", "integer", "numeric"),
properties = c("weights", "selected_features"),
properties = c("importance", "selected_features", "weights"),
packages = c("mlr3learners", "glmnet"),
man = "mlr3learners::mlr_learners_regr.cv_glmnet"
)
Expand All @@ -95,6 +95,18 @@ LearnerRegrCVGlmnet = R6Class("LearnerRegrCVGlmnet",
#' @return (`character()`) of feature names.
selected_features = function(lambda = NULL) {
glmnet_selected_features(self, lambda)
},

#' @description
#' Returns importance scores, calculated from the path of lambda values.
#' First, the largest `lambda` at which the feature was first included in the model
#' with a nonzero coefficient is determined.
#' Second, the [rank()] of these lambda values is calculated (using averaging for ties)
#' and returned as importance scores.
#'
#' @return (named `numeric()`) of importance scores.
importance = function() {
glmnet_importance(self)
}
),

Expand Down
14 changes: 13 additions & 1 deletion R/LearnerRegrGlmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ LearnerRegrGlmnet = R6Class("LearnerRegrGlmnet",
id = "regr.glmnet",
param_set = ps,
feature_types = c("logical", "integer", "numeric"),
properties = "weights",
properties = c("weights", "selected_features", "importance"),
packages = c("mlr3learners", "glmnet"),
man = "mlr3learners::mlr_learners_regr.glmnet"
)
Expand All @@ -95,6 +95,18 @@ LearnerRegrGlmnet = R6Class("LearnerRegrGlmnet",
#' @return (`character()`) of feature names.
selected_features = function(lambda = NULL) {
glmnet_selected_features(self, lambda)
},

#' @description
#' Returns importance scores, calculated from the path of lambda values.
#' First, the largest `lambda` at which the feature was first included in the model
#' with a nonzero coefficient is determined.
#' Second, the [rank()] of these lambda values is calculated (using averaging for ties)
#' and returned as importance scores.
#'
#' @return (named `numeric()`) of importance scores.
importance = function() {
glmnet_importance(self)
}
),

Expand Down
11 changes: 10 additions & 1 deletion R/LearnerRegrLM.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ LearnerRegrLM = R6Class("LearnerRegrLM",
param_set = ps,
predict_types = c("response", "se"),
feature_types = c("logical", "integer", "numeric", "factor", "character"),
properties = c("weights", "loglik"),
properties = c("weights", "loglik", "importance"),
packages = c("mlr3learners", "stats"),
man = "mlr3learners::mlr_learners_regr.lm"
)
Expand All @@ -51,6 +51,15 @@ LearnerRegrLM = R6Class("LearnerRegrLM",
#' Extract the log-likelihood (e.g., via [stats::logLik()]) from the fitted model.
loglik = function() {
extract_loglik(self)
},

#' @description
#' Importance scores as \eqn{-log_{10}()}{-log10()} transformed \eqn{p}-values,
#' extracted from [summary()].
#' Does not work if the model has been fitted on factor features with more than 2 levels.
#' @return Named `numeric()`.
importance = function() {
lin_model_importance(self)
}
),

Expand Down
14 changes: 13 additions & 1 deletion R/LearnerSurvCVGlmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ LearnerSurvCVGlmnet = R6Class("LearnerSurvCVGlmnet",
param_set = ps,
feature_types = c("logical", "integer", "numeric"),
predict_types = c("crank", "lp"),
properties = c("weights", "selected_features"),
properties = c("importance", "selected_features", "weights"),
packages = c("mlr3learners", "glmnet"),
man = "mlr3learners::mlr_learners_surv.cv_glmnet"
)
Expand All @@ -91,6 +91,18 @@ LearnerSurvCVGlmnet = R6Class("LearnerSurvCVGlmnet",
#' @return (`character()`) of feature names.
selected_features = function(lambda = NULL) {
glmnet_selected_features(self, lambda)
},

#' @description
#' Returns importance scores, calculated from the path of lambda values.
#' First, the largest `lambda` at which the feature was first included in the model
#' with a nonzero coefficient is determined.
#' Second, the [rank()] of these lambda values is calculated (using averaging for ties)
#' and returned as importance scores.
#'
#' @return (named `numeric()`) of importance scores.
importance = function() {
glmnet_importance(self)
}
),

Expand Down
14 changes: 13 additions & 1 deletion R/LearnerSurvGlmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ LearnerSurvGlmnet = R6Class("LearnerSurvGlmnet",
param_set = ps,
feature_types = c("logical", "integer", "numeric"),
predict_types = c("crank", "lp"),
properties = c("weights", "selected_features"),
properties = c("importance", "selected_features", "weights"),
packages = c("mlr3learners", "glmnet"),
man = "mlr3learners::mlr_learners_surv.glmnet"
)
Expand All @@ -90,6 +90,18 @@ LearnerSurvGlmnet = R6Class("LearnerSurvGlmnet",
#' @return (`character()`) of feature names.
selected_features = function(lambda = NULL) {
glmnet_selected_features(self, lambda)
},

#' @description
#' Returns importance scores, calculated from the path of lambda values.
#' First, the largest `lambda` at which the feature was first included in the model
#' with a nonzero coefficient is determined.
#' Second, the [rank()] of these lambda values is calculated (using averaging for ties)
#' and returned as importance scores.
#'
#' @return (named `numeric()`) of importance scores.
importance = function() {
glmnet_importance(self)
}
),

Expand Down
5 changes: 2 additions & 3 deletions R/helpers.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
opts_default_contrasts = list(contrasts = c("contr.treatment", "contr.poly"))

# p = probability for levs[2] => matrix with probs for levs[1] and levs[2]
pvec2mat = function(p, levs) {
stopifnot(is.numeric(p))
Expand Down Expand Up @@ -42,6 +44,3 @@ extract_loglik = function(self) {
}
stats::logLik(self$model)
}


opts_default_contrasts = list(contrasts = c("contr.treatment", "contr.poly"))
34 changes: 34 additions & 0 deletions R/helpers_glmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,40 @@ glmnet_selected_features = function(self, lambda = NULL) {
}


# Importance scores for (cv.)glmnet models, derived from the regularization path.
#
# For each feature, find the first column (with columns ordered by increasing
# lambda) at which its coefficient is penalized to exactly 0. Features that
# survive a larger lambda before being zeroed out are ranked as more important;
# features that are never zeroed on the path get position Inf (highest rank).
#
# @param self (`Learner`) a trained glmnet-based learner; reads `self$model`.
# @return (named `numeric()`) importance ranks (ties averaged), sorted decreasing.
glmnet_importance = function(self) {
  # First position (in increasing-lambda order) where each coefficient is 0,
  # Inf if the coefficient is nonzero along the whole path.
  # The intercept row is dropped before ranking.
  first_zero_pos = function(beta, lambdas) {
    beta = beta[rownames(beta) != "(Intercept)", order(lambdas), drop = FALSE]
    apply(beta, 1L, function(x) {
      i = wf(x == 0, use.names = FALSE)
      if (length(i)) i else Inf
    })
  }

  # cv.glmnet stores the underlying fit in $glmnet.fit; plain glmnet is the model itself
  model = self$model$glmnet.fit %??% self$model
  lambdas = model$lambda
  M = coef(model)

  if (is.list(M)) {
    # multinomial models return one coefficient matrix per class;
    # treat a feature as included as long as it is nonzero in ANY class,
    # i.e. take the elementwise maximum position across classes
    # NOTE(review): combining via pmax is one reasonable choice -- confirm
    # against the intended multiclass semantics
    pos = do.call(pmax, lapply(M, first_zero_pos, lambdas = lambdas))
  } else {
    pos = first_zero_pos(M, lambdas)
  }

  sort(rank(pos, ties.method = "average"), decreasing = TRUE)
}


glmnet_invoke = function(data, target, pv, cv = FALSE) {
saved_ctrl = glmnet::glmnet.control()
on.exit(invoke(glmnet::glmnet.control, .args = saved_ctrl))
Expand Down
20 changes: 20 additions & 0 deletions R/helpers_lin_models.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Importance scores for lm/glm-style learners as -log10 transformed p-values
# of the coefficient Wald tests, extracted from summary()$coefficients.
#
# Only works if no factor feature has more than 2 levels: with > 2 levels a
# single feature maps to multiple dummy coefficients and cannot be assigned
# one score.
#
# @param self (`Learner`) a trained learner; reads `self$model` and
#   `self$state$train_task`.
# @return (named `numeric()`) importance scores, sorted decreasing.
lin_model_importance = function(self) {
  task = self$state$train_task
  lvls = task$levels(task$feature_names, include_logicals = TRUE)
  nlvls = lengths(lvls)
  if (any(nlvls > 2L)) {
    # a factor with k > 2 levels expands to k - 1 coefficients -> ambiguous
    stopf("Importance cannot be extracted for models fitted on factors with more than 2 levels")
  }

  # 4th column of the coefficient table holds the p-values (Pr(>|t|) / Pr(>|z|))
  pvals = summary(self$model)$coefficients[, 4L]
  pvals = pvals[names(pvals) != "(Intercept)"]

  # binary factors get the 2nd level appended to the coefficient name
  # (e.g. "sexmale"); strip it so scores are keyed by the feature name
  ii = (nlvls == 2L)
  pvals = rename(pvals,
    old = paste0(names(nlvls)[ii], map_chr(lvls[ii], tail, 1L)),
    new = names(nlvls)[ii]
  )

  sort(-log10(pvals), decreasing = TRUE)
}
3 changes: 2 additions & 1 deletion R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
#' @import checkmate
#' @importFrom R6 R6Class
#' @importFrom mlr3 mlr_learners LearnerClassif LearnerRegr
#' @importFrom stats predict reformulate
#' @importFrom stats predict coef reformulate
#' @importFrom utils tail
#'
#' @description
#' More learners are implemented in the [mlr3extralearners package](https://github.com/mlr-org/mlr3extralearners).
Expand Down
18 changes: 18 additions & 0 deletions man/mlr_learners_classif.cv_glmnet.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 18 additions & 0 deletions man/mlr_learners_classif.glmnet.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading