Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: regression learners and estimate memory functions #18

Merged
merged 9 commits into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,9 @@ Suggests:
xgboost
Remotes:
catboost/catboost/catboost/R-package,
mlr-org/mlr3@learner_size,
mlr-org/mlr3extralearners@mlr3automl,
mlr-org/mlr3learners@mlr3automl,
mlr-org/mlr3mbo@adbo,
mlr-org/mlr3pipelines
mlr-org/mlr3,
mlr-org/mlr3extralearners,
mlr-org/mlr3mbo@adbo
Config/testthat/edition: 3
Config/testthat/parallel: false
Encoding: UTF-8
Expand All @@ -67,7 +65,17 @@ Collate:
'LearnerClassifAutoSVM.R'
'LearnerClassifAutoXgboost.R'
'LearnerRegrAuto.R'
'helper.R'
'LearnerRegrAutoCatboost.R'
'LearnerRegrAutoGlmnet.R'
'LearnerRegrAutoKKNN.R'
'LearnerRegrAutoLightGBM.R'
'LearnerRegrAutoNnet.R'
'LearnerRegrAutoRanger.R'
'LearnerRegrAutoSVM.R'
'LearnerRegrAutoXgboost.R'
'build_graph.R'
'estimate_memory.R'
'helper.R'
'internal_measure.R'
'train_auto.R'
'zzz.R'
27 changes: 27 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,9 +1,36 @@
# Generated by roxygen2: do not edit by hand

S3method(estimate_memory,Learner)
S3method(estimate_memory,LearnerClassifCatboost)
S3method(estimate_memory,LearnerClassifLightGBM)
S3method(estimate_memory,LearnerClassifRanger)
S3method(estimate_memory,LearnerClassifXgboost)
S3method(estimate_memory,LearnerRegrCatboost)
S3method(estimate_memory,LearnerRegrLightGBM)
S3method(estimate_memory,LearnerRegrRanger)
S3method(estimate_memory,LearnerRegrXgboost)
export(LearnerClassifAuto)
export(LearnerClassifAutoCatboost)
export(LearnerClassifAutoGlmnet)
export(LearnerClassifAutoKKNN)
export(LearnerClassifAutoLightGBM)
export(LearnerClassifAutoNnet)
export(LearnerClassifAutoRanger)
export(LearnerClassifAutoSVM)
export(LearnerClassifAutoXgboost)
export(LearnerRegrAuto)
export(LearnerRegrAutoCatboost)
export(LearnerRegrAutoGlmnet)
export(LearnerRegrAutoKKNN)
export(LearnerRegrAutoLightGBM)
export(LearnerRegrAutoNnet)
export(LearnerRegrAutoRanger)
export(LearnerRegrAutoSVM)
export(LearnerRegrAutoXgboost)
export(catboost_internal_measure)
export(estimate_memory)
export(lightgbm_internal_measure)
export(xgboost_internal_measure)
import(R6)
import(checkmate)
import(data.table)
Expand Down
80 changes: 47 additions & 33 deletions R/LearnerClassifAuto.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#' Classification auto learner.
#'
#' @template param_id
#' @template param_learner_ids
#'
#' @export
LearnerClassifAuto = R6Class("LearnerClassifAuto",
Expand All @@ -21,47 +22,52 @@ LearnerClassifAuto = R6Class("LearnerClassifAuto",

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id = "classif.auto") {
initialize = function(
id = "classif.auto",
learner_ids = c("glmnet", "kknn", "lda", "nnet", "ranger", "svm", "xgboost", "catboost", "extra_trees", "lightgbm")
) {
assert_subset(learner_ids, c("glmnet", "kknn", "lda", "nnet", "ranger", "svm", "xgboost", "catboost", "extra_trees", "lightgbm"))
if (all(learner_ids %in% c("lda", "extra_trees"))) {
stop("Learner 'lda' and 'extra_trees' must be combined with other learners")
}

private$.learner_ids = learner_ids
self$tuning_space = tuning_space[private$.learner_ids]

param_set = ps(
# learner
learner_ids = p_uty(default = c("glmnet", "kknn", "lda", "nnet", "ranger", "svm", "xgboost", "catboost", "extra_trees", "lightgbm"),
custom_check = function(x) {
if (all(x %in% c("lda", "extra_trees"))) {
return("Learner 'lda' and 'extra_trees' must be combined with other learners")
}
check_subset(x, c("glmnet", "kknn", "lda", "nnet", "ranger", "svm", "xgboost", "catboost", "extra_trees", "lightgbm"))
}),
learner_timeout = p_int(lower = 1L, default = 900L),
xgboost_eval_metric = p_uty(),
catboost_eval_metric = p_uty(),
lightgbm_eval_metric = p_uty(),
learner_timeout = p_int(lower = 1L, default = 900L, tags = c("train", "super")),
# internal eval metric
xgboost_eval_metric = p_uty(tags = c("train", "xgboost")),
catboost_eval_metric = p_uty(tags = c("train", "catboost")),
lightgbm_eval_metric = p_uty(tags = c("train", "lightgbm")),
# system
max_nthread = p_int(lower = 1L, default = 1L),
max_memory = p_int(lower = 1L, default = 32000L),
max_nthread = p_int(lower = 1L, default = 1L, tags = c("train", "catboost", "lightgbm", "ranger", "xgboost")),
max_memory = p_int(lower = 1L, default = 32000L, tags = c("train", "catboost", "lightgbm", "ranger", "xgboost")),
# large data
large_data_size = p_int(lower = 1L, default = 1e6),
large_data_learner_ids = p_uty(),
large_data_nthread = p_int(lower = 1L, default = 4L),
large_data_size = p_int(lower = 1L, default = 1e6, tags = c("train", "super")),
large_data_learner_ids = p_uty(tags = c("train", "super")),
large_data_nthread = p_int(lower = 1L, default = 4L, tags = c("train", "catboost", "lightgbm", "ranger", "xgboost")),
# small data
small_data_size = p_int(lower = 1L, default = 5000L),
small_data_resampling = p_uty(),
max_cardinality = p_int(lower = 1L, default = 100L),
extra_trees_max_cardinality = p_int(lower = 1L, default = 40L),
small_data_size = p_int(lower = 1L, default = 5000L, tags = c("train", "super")),
small_data_resampling = p_uty(tags = c("train", "super")),
# cardinality
max_cardinality = p_int(lower = 1L, default = 100L, tags = c("train", "super")),
extra_trees_max_cardinality = p_int(lower = 1L, default = 40L, tags = c("train", "extra_trees")),
# tuner
resampling = p_uty(),
terminator = p_uty(),
measure = p_uty(),
lhs_size = p_int(lower = 1L, default = 4L),
callbacks = p_uty(),
store_benchmark_result = p_lgl(default = FALSE))
resampling = p_uty(tags = c("train", "super")),
terminator = p_uty(tags = c("train", "super")),
measure = p_uty(tags = c("train", "super")),
lhs_size = p_int(lower = 1L, default = 4L, tags = c("train", "super")),
callbacks = p_uty(tags = c("train", "super")),
store_benchmark_result = p_lgl(default = FALSE, tags = c("train", "super")),
store_models = p_lgl(default = FALSE, tags = c("train", "super")))

param_set$set_values(
learner_ids = c("glmnet", "kknn", "lda", "nnet", "ranger", "svm", "xgboost", "catboost", "extra_trees", "lightgbm"),
learner_timeout = 900L,
max_nthread = 1L,
max_memory = 32000L,
large_data_size = 1e6L,
large_data_learner_ids = c("lda", "ranger", "xgboost", "catboost", "extra_trees", "lightgbm"),
large_data_learner_ids = intersect(c("lda", "ranger", "xgboost", "catboost", "extra_trees", "lightgbm"), private$.learner_ids),
large_data_nthread = 4L,
small_data_size = 5000L,
small_data_resampling = rsmp("cv", folds = 10L),
Expand All @@ -71,13 +77,19 @@ LearnerClassifAuto = R6Class("LearnerClassifAuto",
terminator = trm("run_time", secs = 14400L),
measure = msr("classif.ce"),
lhs_size = 4L,
store_benchmark_result = FALSE)
store_benchmark_result = FALSE,
store_models = FALSE)

# subset to relevant parameters for selected learners
param_set = param_set$subset(ids = unique(param_set$ids(any_tags = c("super", learner_ids))))

self$graph = build_graph(private$.learner_ids, "classif")

super$initialize(
id = id,
task_type = "classif",
param_set = param_set,
packages = c("mlr3tuning", "mlr3learners", "mlr3pipelines", "mlr3mbo", "mlr3automl", "xgboost", "catboost", "lightgbm", "ranger", "nnet", "kknn", "glmnet", "MASS", "e1071"),
packages = c("mlr3tuning", "mlr3learners", "mlr3pipelines", "mlr3mbo", "mlr3automl", self$graph$packages),
feature_types = c("logical", "integer", "numeric", "character", "factor"),
predict_types = c("response", "prob"),
properties = c("missings", "weights", "twoclass", "multiclass"),
Expand All @@ -86,8 +98,10 @@ LearnerClassifAuto = R6Class("LearnerClassifAuto",
),

private = list(
.learner_ids = NULL,

.train = function(task) {
train_auto(self, task, task_type = "classif")
train_auto(self, private, task)
},

.predict = function(task) {
Expand Down
21 changes: 1 addition & 20 deletions R/LearnerClassifAutoCatboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,7 @@ LearnerClassifAutoCatboost = R6Class("LearnerClassifAutoCatboost",
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id = "classif.auto_catboost") {
super$initialize(id = id)

# reduce parameter set to the relevant parameters
private$.param_set = private$.param_set$subset(
c("learner_ids",
"learner_timeout",
"catboost_eval_metric",
"small_data_size",
"small_data_resampling",
"max_cardinality",
"resampling",
"terminator",
"measure",
"lhs_size",
"callbacks",
"store_benchmark_result")
)

self$param_set$set_values(learner_ids = "catboost")
self$packages = c("mlr3tuning", "mlr3learners", "mlr3pipelines", "mlr3mbo", "mlr3automl", "catboost")
super$initialize(id = id, learner_ids = "catboost")
}
)
)
Expand Down
20 changes: 1 addition & 19 deletions R/LearnerClassifAutoGlmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,7 @@ LearnerClassifAutoGlmnet = R6Class("LearnerClassifAutoGlmnet",
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id = "classif.auto_glmnet") {
super$initialize(id = id)

# reduce parameter set to the relevant parameters
private$.param_set = private$.param_set$subset(
c("learner_ids",
"learner_timeout",
"small_data_size",
"small_data_resampling",
"max_cardinality",
"resampling",
"terminator",
"measure",
"lhs_size",
"callbacks",
"store_benchmark_result")
)

self$param_set$set_values(learner_ids = "glmnet")
self$packages = c("mlr3tuning", "mlr3learners", "mlr3pipelines", "mlr3mbo", "mlr3automl", "glmnet")
super$initialize(id = id, learner_ids = "glmnet")
}
)
)
Expand Down
20 changes: 1 addition & 19 deletions R/LearnerClassifAutoKKNN.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,7 @@ LearnerClassifAutoKKNN = R6Class("LearnerClassifAutoKKNN",
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id = "classif.auto_kknn") {
super$initialize(id = id)

# reduce parameter set to the relevant parameters
private$.param_set = private$.param_set$subset(
c("learner_ids",
"learner_timeout",
"small_data_size",
"small_data_resampling",
"max_cardinality",
"resampling",
"terminator",
"measure",
"lhs_size",
"callbacks",
"store_benchmark_result")
)

self$param_set$set_values(learner_ids = "kknn")
self$packages = c("mlr3tuning", "mlr3learners", "mlr3pipelines", "mlr3mbo", "mlr3automl", "kknn")
super$initialize(id = id, learner_ids = "kknn")
}
)
)
Expand Down
21 changes: 1 addition & 20 deletions R/LearnerClassifAutoLightGBM.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,7 @@ LearnerClassifAutoLightGBM = R6Class("LearnerClassifAutoLightGBM",
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id = "classif.auto_lightgbm") {
super$initialize(id = id)

# reduce parameter set to the relevant parameters
private$.param_set = private$.param_set$subset(
c("learner_ids",
"learner_timeout",
"lightgbm_eval_metric",
"small_data_size",
"small_data_resampling",
"max_cardinality",
"resampling",
"terminator",
"measure",
"lhs_size",
"callbacks",
"store_benchmark_result")
)

self$param_set$set_values(learner_ids = "lightgbm")
self$packages = c("mlr3tuning", "mlr3learners", "mlr3pipelines", "mlr3mbo", "mlr3automl", "lightgbm")
super$initialize(id = id, learner_ids = "lightgbm")
}
)
)
Expand Down
20 changes: 1 addition & 19 deletions R/LearnerClassifAutoNnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,7 @@ LearnerClassifAutoNnet = R6Class("LearnerClassifAutoNnet",
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id = "classif.auto_nnet") {
super$initialize(id = id)

# reduce parameter set to the relevant parameters
private$.param_set = private$.param_set$subset(
c("learner_ids",
"learner_timeout",
"small_data_size",
"small_data_resampling",
"max_cardinality",
"resampling",
"terminator",
"measure",
"lhs_size",
"callbacks",
"store_benchmark_result")
)

self$param_set$set_values(learner_ids = "nnet")
self$packages = c("mlr3tuning", "mlr3learners", "mlr3pipelines", "mlr3mbo", "mlr3automl", "nnet")
super$initialize(id = id, learner_ids = "nnet")
}
)
)
Expand Down
20 changes: 1 addition & 19 deletions R/LearnerClassifAutoRanger.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,7 @@ LearnerClassifAutoRanger = R6Class("LearnerClassifAutoRanger",
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id = "classif.auto_ranger") {
super$initialize(id = id)

# reduce parameter set to the relevant parameters
private$.param_set = private$.param_set$subset(
c("learner_ids",
"learner_timeout",
"small_data_size",
"small_data_resampling",
"max_cardinality",
"resampling",
"terminator",
"measure",
"lhs_size",
"callbacks",
"store_benchmark_result")
)

self$param_set$set_values(learner_ids = "ranger")
self$packages = c("mlr3tuning", "mlr3learners", "mlr3pipelines", "mlr3mbo", "mlr3automl", "ranger")
super$initialize(id = id, learner_ids = "ranger")
}
)
)
Expand Down
20 changes: 1 addition & 19 deletions R/LearnerClassifAutoSVM.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,7 @@ LearnerClassifAutoSVM = R6Class("LearnerClassifAutoSVM",
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id = "classif.auto_svm") {
super$initialize(id = id)

# reduce parameter set to the relevant parameters
private$.param_set = private$.param_set$subset(
c("learner_ids",
"learner_timeout",
"small_data_size",
"small_data_resampling",
"max_cardinality",
"resampling",
"terminator",
"measure",
"lhs_size",
"callbacks",
"store_benchmark_result")
)

self$param_set$set_values(learner_ids = "svm")
self$packages = c("mlr3tuning", "mlr3learners", "mlr3pipelines", "mlr3mbo", "mlr3automl", "e1071")
super$initialize(id = id, learner_ids = "svm")
}
)
)
Expand Down
Loading
Loading