Skip to content

Commit

Permalink
Add importance to glmnet
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Dec 10, 2019
1 parent ad9ab82 commit 7594ad7
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
14 changes: 13 additions & 1 deletion R/LearnerClassifGlmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ LearnerClassifGlmnet = R6Class("LearnerClassifGlmnet", inherit = LearnerClassif,
param_set = ps,
predict_types = c("response", "prob"),
feature_types = c("integer", "numeric"),
properties = c("weights", "twoclass", "multiclass"),
properties = c("weights", "twoclass", "multiclass", "importance"),
packages = "glmnet",
man = "mlr3learners::mlr_learners_classif.glmnet"
)
Expand Down Expand Up @@ -111,6 +111,18 @@ LearnerClassifGlmnet = R6Class("LearnerClassifGlmnet", inherit = LearnerClassif,
}
PredictionClassif$new(task = task, prob = prob)
}
},

importance = function() {
model = self$model$glmnet.fit

res = sapply(seq_len(nrow(model$beta)), function(i) {

This comment has been minimized.

Copy link
@pat-s

pat-s Dec 14, 2019

Member

@be-marc Instead of sapply(), use mlr3misc::map_*().

ind = which(model$beta[i,] != 0)[1]
model$lambda[ind]
})

names(res) = model$beta@Dimnames[[1]]
sort(res, decreasing = TRUE)
}
)
)
14 changes: 13 additions & 1 deletion R/LearnerRegrGlmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ LearnerRegrGlmnet = R6Class("LearnerRegrGlmnet", inherit = LearnerRegr,
id = "regr.glmnet",
param_set = ps,
feature_types = c("integer", "numeric"),
properties = "weights",
properties = c("weights", "importance"),
packages = "glmnet",
man = "mlr3learners::mlr_learners_regr.glmnet"
)
Expand Down Expand Up @@ -105,6 +105,18 @@ LearnerRegrGlmnet = R6Class("LearnerRegrGlmnet", inherit = LearnerRegr,

response = invoke(predict, self$model, newx = newdata, type = "response", .args = pars)
PredictionRegr$new(task = task, response = drop(response))
},

importance = function() {
model = self$model$glmnet.fit

res = sapply(seq_len(nrow(model$beta)), function(i) {
ind = which(model$beta[i,] != 0)[1]
model$lambda[ind]
})

names(res) = model$beta@Dimnames[[1]]
sort(res, decreasing = TRUE)
}
)
)

0 comments on commit 7594ad7

Please sign in to comment.