Skip to content

Commit

Permalink
feat(Learner): support marshal property (#993)
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer authored Apr 23, 2024
1 parent f003de1 commit 6a11743
Show file tree
Hide file tree
Showing 38 changed files with 1,027 additions and 46 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ Collate:
'helper_hashes.R'
'helper_print.R'
'install_pkgs.R'
'marshal.R'
'mlr_sugar.R'
'mlr_test_helpers.R'
'partition.R'
Expand Down
13 changes: 13 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ S3method(fix_factor_levels,data.table)
S3method(head,Task)
S3method(is_missing_prediction_data,PredictionDataClassif)
S3method(is_missing_prediction_data,PredictionDataRegr)
S3method(marshal_model,classif.debug_model)
S3method(marshal_model,default)
S3method(marshal_model,learner_state)
S3method(partition,Task)
S3method(partition,TaskClassif)
S3method(partition,TaskRegr)
Expand All @@ -95,6 +98,7 @@ S3method(print,PredictionData)
S3method(print,benchmark_grid)
S3method(print,bmr_aggregate)
S3method(print,bmr_score)
S3method(print,marshaled)
S3method(print,rr_score)
S3method(rd_info,Learner)
S3method(rd_info,Measure)
Expand All @@ -104,6 +108,9 @@ S3method(set_threads,default)
S3method(set_threads,list)
S3method(summary,Task)
S3method(tail,Task)
S3method(unmarshal_model,classif.debug_model_marshaled)
S3method(unmarshal_model,default)
S3method(unmarshal_model,learner_state_marshaled)
export(BenchmarkResult)
export(DataBackend)
export(DataBackendDataTable)
Expand Down Expand Up @@ -207,9 +214,14 @@ export(default_measures)
export(extract_pkgs)
export(filter_prediction_data)
export(install_pkgs)
export(is_marshaled_model)
export(is_missing_prediction_data)
export(learner_marshal)
export(learner_marshaled)
export(learner_unmarshal)
export(lrn)
export(lrns)
export(marshal_model)
export(mlr_learners)
export(mlr_measures)
export(mlr_reflections)
Expand All @@ -227,6 +239,7 @@ export(tgen)
export(tgens)
export(tsk)
export(tsks)
export(unmarshal_model)
import(checkmate)
import(data.table)
import(mlr3misc)
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# mlr3 (development version)

* Feat: added support for `"marshal"` property, which allows learners to process
models so they can be serialized. This happens automatically during `resample()`
and `benchmark()`.
* Log encapsulated errors and warnings with the `lgr` package.

# mlr3 0.18.0
Expand Down
14 changes: 14 additions & 0 deletions R/BenchmarkResult.R
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,20 @@ BenchmarkResult = R6Class("BenchmarkResult",
invisible(self)
},

#' @description
#' Marshals all stored models.
#' @param ... (any)\cr
#' Additional arguments passed to [`marshal_model()`].
marshal = function(...) {
private$.data$marshal(...)
},
#' @description
#' Unmarshals all stored models.
#' @param ... (any)\cr
#' Additional arguments passed to [`unmarshal_model()`].
unmarshal = function(...) {
private$.data$unmarshal(...)
},

#' @description
#' Returns a table with one row for each resampling iteration, including
Expand Down
11 changes: 7 additions & 4 deletions R/HotstartStack.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,13 @@ HotstartStack = R6Class("HotstartStack",
add = function(learners) {
learners = assert_learners(as_learners(learners))

# check for models
if (any(map_lgl(learners, function(learner) is.null(learner$state$model)))) {
stopf("Learners must be trained before adding them to the hotstart stack.")
}
walk(learners, function(learner) {
if (is.null(learner$model)) {
stopf("Learners must be trained before adding them to the hotstart stack.")
} else if (is_marshaled_model(learner$model)) {
stopf("Learners must be unmarshaled before adding them to the hotstart stack.")
}
})

if (!is.null(self$hotstart_threshold)) {
learners = keep(learners, function(learner) {
Expand Down
33 changes: 30 additions & 3 deletions R/Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ Learner = R6Class("Learner",
#' @param ... (ignored).
print = function(...) {
catn(format(self), if (is.null(self$label) || is.na(self$label)) "" else paste0(": ", self$label))
catn(str_indent("* Model:", if (is.null(self$model)) "-" else class(self$model)[1L]))
catn(str_indent("* Model:", if (is.null(self$model)) "-" else if (is_marshaled_model(self$model)) "<marshaled>" else paste0(class(self$model)[1L])))
catn(str_indent("* Parameters:", as_short_string(self$param_set$values, 1000L)))
catn(str_indent("* Packages:", self$packages))
catn(str_indent("* Predict Types: ", replace(self$predict_types, self$predict_types == self$predict_type, paste0("[", self$predict_type, "]"))))
Expand Down Expand Up @@ -243,6 +243,7 @@ Learner = R6Class("Learner",
test_row_ids = task$row_roles$test

learner_train(learner, task, train_row_ids = train_row_ids, test_row_ids = test_row_ids, mode = mode)
self$model = unmarshal_model(model = self$state$model, inplace = TRUE)

# store data prototype
proto = task$data(rows = integer())
Expand Down Expand Up @@ -279,6 +280,10 @@ Learner = R6Class("Learner",
stopf("Cannot predict, Learner '%s' has not been trained yet", self$id)
}

if (is_marshaled_model(self$model)) {
stopf("Cannot predict, Learner '%s' has not been unmarshaled yet", self$id)
}

if (isTRUE(self$parallel_predict) && nbrOfWorkers() > 1L) {
row_ids = row_ids %??% task$row_ids
chunked = chunk_vector(row_ids, n_chunks = nbrOfWorkers(), shuffle = FALSE)
Expand Down Expand Up @@ -388,7 +393,6 @@ Learner = R6Class("Learner",
self$state$model
},


#' @field timings (named `numeric(2)`)\cr
#' Elapsed time in seconds for the steps `"train"` and `"predict"`.
#' Measured via [mlr3misc::encapsulate()].
Expand Down Expand Up @@ -541,7 +545,6 @@ Learner = R6Class("Learner",
)
)


#' @export
rd_info.Learner = function(obj, ...) {
x = c("",
Expand Down Expand Up @@ -576,3 +579,27 @@ default_values.Learner = function(x, search_space, task, ...) { # nolint
# format_list_item.Learner = function(x, ...) { # nolint
# sprintf("<lrn:%s>", x$id)
# }


#' @export
marshal_model.learner_state = function(model, inplace = FALSE, ...) {
if (is.null(model$model)) {
return(model)
}
mm = marshal_model(model$model, inplace = inplace, ...)
if (!is_marshaled_model(mm)) {
return(model)
}
model$model = mm
structure(list(
marshaled = model,
packages = "mlr3"
), class = c("learner_state_marshaled", "list_marshaled", "marshaled"))
}

#' @export
unmarshal_model.learner_state_marshaled = function(model, inplace = FALSE, ...) {
mm = model$marshaled
mm$model = unmarshal_model(mm$model, inplace = inplace, ...)
return(mm)
}
86 changes: 66 additions & 20 deletions R/LearnerClassifDebug.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#' \item{warning_train:}{Probability to signal a warning during train.}
#' \item{x:}{Numeric tuning parameter. Has no effect.}
#' \item{iter:}{Integer parameter for testing hotstarting.}
#' \item{count_marshaling:}{If `TRUE`, `marshal_model` will increase the `marshal_count` by 1 each time it is called. The default is `FALSE`.}
#' }
#' Note that segfaults may not be triggered reliably on your operating system.
#' Also note that if they work as intended, they will tear down your R session immediately!
Expand All @@ -49,39 +50,62 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function() {
param_set = ps(
error_predict = p_dbl(0, 1, default = 0, tags = "predict"),
error_train = p_dbl(0, 1, default = 0, tags = "train"),
message_predict = p_dbl(0, 1, default = 0, tags = "predict"),
message_train = p_dbl(0, 1, default = 0, tags = "train"),
predict_missing = p_dbl(0, 1, default = 0, tags = "predict"),
predict_missing_type = p_fct(c("na", "omit"), default = "na", tags = "predict"),
save_tasks = p_lgl(default = FALSE, tags = c("train", "predict")),
segfault_predict = p_dbl(0, 1, default = 0, tags = "predict"),
segfault_train = p_dbl(0, 1, default = 0, tags = "train"),
sleep_train = p_uty(tags = "train"),
sleep_predict = p_uty(tags = "predict"),
threads = p_int(1L, tags = c("train", "threads")),
warning_predict = p_dbl(0, 1, default = 0, tags = "predict"),
warning_train = p_dbl(0, 1, default = 0, tags = "train"),
x = p_dbl(0, 1, tags = "train"),
iter = p_int(1, default = 1, tags = c("train", "hotstart")),
count_marshaling = p_lgl(default = FALSE, tags = "train")
)
super$initialize(
id = "classif.debug",
param_set = param_set,
feature_types = c("logical", "integer", "numeric", "character", "factor", "ordered"),
predict_types = c("response", "prob"),
param_set = ps(
error_predict = p_dbl(0, 1, default = 0, tags = "predict"),
error_train = p_dbl(0, 1, default = 0, tags = "train"),
message_predict = p_dbl(0, 1, default = 0, tags = "predict"),
message_train = p_dbl(0, 1, default = 0, tags = "train"),
predict_missing = p_dbl(0, 1, default = 0, tags = "predict"),
predict_missing_type = p_fct(c("na", "omit"), default = "na", tags = "predict"),
save_tasks = p_lgl(default = FALSE, tags = c("train", "predict")),
segfault_predict = p_dbl(0, 1, default = 0, tags = "predict"),
segfault_train = p_dbl(0, 1, default = 0, tags = "train"),
sleep_train = p_uty(tags = "train"),
sleep_predict = p_uty(tags = "predict"),
threads = p_int(1L, tags = c("train", "threads")),
warning_predict = p_dbl(0, 1, default = 0, tags = "predict"),
warning_train = p_dbl(0, 1, default = 0, tags = "train"),
x = p_dbl(0, 1, tags = "train"),
iter = p_int(1, default = 1, tags = c("train", "hotstart"))
),
properties = c("twoclass", "multiclass", "missings", "hotstart_forward"),
properties = c("twoclass", "multiclass", "missings", "hotstart_forward", "marshal"),
man = "mlr3::mlr_learners_classif.debug",
data_formats = c("data.table", "Matrix"),
label = "Debug Learner for Classification"
)
},
#' @description
#' Marshal the learner's model.
#' @param ... (any)\cr
#' Additional arguments passed to [`marshal_model()`].
marshal = function(...) {
learner_marshal(.learner = self, ...)
},
#' @description
#' Unmarshal the learner's model.
#' @param ... (any)\cr
#' Additional arguments passed to [`unmarshal_model()`].
unmarshal = function(...) {
learner_unmarshal(.learner = self, ...)
}
),
active = list(
#' @field marshaled (logical(1))\cr
#' Whether the learner has been marshaled.
marshaled = function() {
learner_marshaled(self)
}
),

private = list(
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
pv$count_marshaling = pv$count_marshaling %??% FALSE
roll = function(name) {
name %in% names(pv) && pv[[name]] > runif(1L)
}
Expand Down Expand Up @@ -110,6 +134,10 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
model$task_train = task$clone(deep = TRUE)
}

if (isTRUE(pv$count_marshaling)) {
model$marshal_count = 0L
}

set_class(model, "classif.debug_model")
},

Expand Down Expand Up @@ -193,3 +221,21 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,

#' @include mlr_learners.R
mlr_learners$add("classif.debug", function() LearnerClassifDebug$new())

#' @export
#' @method marshal_model classif.debug_model
marshal_model.classif.debug_model = function(model, inplace = FALSE, ...) {
if (!is.null(model$marshal_count)) {
model$marshal_count = model$marshal_count + 1
}
structure(list(
marshaled = model, packages = "mlr3"),
class = c("classif.debug_model_marshaled", "marshaled")
)
}

#' @export
#' @method unmarshal_model classif.debug_model_marshaled
unmarshal_model.classif.debug_model_marshaled = function(model, inplace = FALSE, ...) {
model$marshaled
}
1 change: 0 additions & 1 deletion R/LearnerRegrDebug.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ LearnerRegrDebug = R6Class("LearnerRegrDebug", inherit = LearnerRegr,
)
}
),

private = list(
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
Expand Down
4 changes: 4 additions & 0 deletions R/Measure.R
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ Measure = R6Class("Measure",
assert_measure(self, task = task, learner = learner)
assert_prediction(prediction)


if ("requires_task" %in% self$properties && is.null(task)) {
stopf("Measure '%s' requires a task", self$id)
}
Expand All @@ -184,6 +185,9 @@ Measure = R6Class("Measure",
if ("requires_model" %in% self$properties && (is.null(learner) || is.null(learner$model))) {
stopf("Measure '%s' requires the trained model", self$id)
}
if ("requires_model" %in% self$properties && is_marshaled_model(learner$model)) {
stopf("Measure '%s' requires the trained model, but model is in marshaled form", self$id)
}

if ("requires_train_set" %in% self$properties && is.null(train_set)) {
stopf("Measure '%s' requires the train_set", self$id)
Expand Down
15 changes: 15 additions & 0 deletions R/ResampleResult.R
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,21 @@ ResampleResult = R6Class("ResampleResult",
#' the object in its previous state.
discard = function(backends = FALSE, models = FALSE) {
private$.data$discard(backends = backends, models = models)
},

#' @description
#' Marshals all stored models.
#' @param ... (any)\cr
#' Additional arguments passed to [`marshal_model()`].
marshal = function(...) {
private$.data$marshal(...)
},
#' @description
#' Unmarshals all stored models.
#' @param ... (any)\cr
#' Additional arguments passed to [`unmarshal_model()`].
unmarshal = function(...) {
private$.data$unmarshal(...)
}
),

Expand Down
21 changes: 21 additions & 0 deletions R/ResultData.R
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,27 @@ ResultData = R6Class("ResultData",
invisible(self)
},

#' @description
#' Marshals all stored learner models.
#' This will do nothing to models that are already marshaled.
#' @param ... (any)\cr
#' Additional arguments passed to [`marshal_model()`].
marshal = function(...) {
learner_state = NULL
self$data$fact[, learner_state := lapply(learner_state, function(x) marshal_state_if_model(.state = x, inplace = TRUE, ...))]
invisible(self)
},
#' @description
#' Unmarshals all stored learner models.
#' This will do nothing to models which are not marshaled.
#' @param ... (any)\cr
#' Additional arguments passed to [`unmarshal_model()`].
unmarshal = function(...) {
learner_state = NULL
self$data$fact[, learner_state := lapply(learner_state, function(x) unmarshal_state_if_model(.state = x, inplace = TRUE, ...))]
invisible(self)
},

#' @description
#' Shrinks the object by discarding parts of the stored data.
#'
Expand Down
Loading

0 comments on commit 6a11743

Please sign in to comment.