Skip to content

Commit

Permalink
feat: allow quantile predictions for regression (#1086)
Browse files Browse the repository at this point in the history
* ...

* initial quantile support

* improve tests

* feat: add quantile_response

* fix: autotest

* refactor: store quantil response column name instead of response vector

* fix: attributes

* refactor: rename quantile to quantiles

* fix: autotest for single quantile

* docs: predictiondata

* fix: prediction predict type

---------

Co-authored-by: be-marc <[email protected]>
  • Loading branch information
mllg and be-marc authored Aug 20, 2024
1 parent 1e6bbef commit 9c6d1e3
Show file tree
Hide file tree
Showing 12 changed files with 187 additions and 11 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ importFrom(parallelly,availableCores)
importFrom(stats,contr.treatment)
importFrom(stats,model.frame)
importFrom(stats,predict)
importFrom(stats,quantile)
importFrom(stats,rnorm)
importFrom(stats,runif)
importFrom(stats,sd)
Expand Down
44 changes: 44 additions & 0 deletions R/LearnerRegr.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,49 @@ LearnerRegr = R6Class("LearnerRegr", inherit = Learner,
predict_types = predict_types, properties = properties, data_formats = data_formats, packages = packages,
label = label, man = man)
}
),

active = list(

#' @field quantiles (`numeric()`)\cr
#' Numeric vector of probabilities to be used while predicting quantiles.
#' Elements must be between 0 and 1, not missing and provided in ascending order.
#' If only one quantile is provided, it is used as response.
#' Otherwise, set `$quantile_response` to specify the response quantile.
quantiles = function(rhs) {
if (missing(rhs)) {
return(private$.quantiles)
}

if ("quantiles" %nin% self$predict_types) {
stopf("Learner does not support predicting quantiles")
}
private$.quantiles = assert_numeric(rhs, lower = 0, upper = 1, any.missing = FALSE, min.len = 1L, sorted = TRUE, .var.name = "quantiles")

if (length(private$.quantiles) == 1) {
private$.quantile_response = private$.quantiles
}
},

#' @field quantile_response (`numeric(1)`)\cr
#' The quantile to be used as response.
quantile_response = function(rhs) {
if (missing(rhs)) {
return(private$.quantile_response)
}

if ("quantiles" %nin% self$predict_types) {
stopf("Learner does not support predicting quantiles")
}

private$.quantile_response = assert_number(rhs, lower = 0, upper = 1, .var.name = "response")
private$.quantiles = sort(union(private$.quantiles, private$.quantile_response))
}
),


private = list(
.quantiles = NULL,
.quantile_response = NULL
)
)
18 changes: 15 additions & 3 deletions R/LearnerRegrDebug.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ LearnerRegrDebug = R6Class("LearnerRegrDebug", inherit = LearnerRegr,
super$initialize(
id = "regr.debug",
feature_types = c("logical", "integer", "numeric", "character", "factor", "ordered"),
predict_types = c("response", "se"),
predict_types = c("response", "se", "quantiles"),
param_set = ps(
predict_missing = p_dbl(0, 1, default = 0, tags = "predict"),
predict_missing_type = p_fct(c("na", "omit"), default = "na", tags = "predict"),
Expand All @@ -61,6 +61,12 @@ LearnerRegrDebug = R6Class("LearnerRegrDebug", inherit = LearnerRegr,
pid = Sys.getpid()
)

if (self$predict_type == "quantiles") {
probs = self$quantiles
model$quantiles = unname(quantile(truth, probs))
model$quantile_probs = probs
}

if (isTRUE(pv$save_tasks)) {
model$task_train = task$clone(deep = TRUE)
}
Expand All @@ -75,7 +81,14 @@ LearnerRegrDebug = R6Class("LearnerRegrDebug", inherit = LearnerRegr,
self$state$model$task_predict = task$clone(deep = TRUE)
}

prediction = named_list(mlr_reflections$learner_predict_types[["regr"]][[self$predict_type]])
if (self$predict_type == "quantiles") {
prediction = list(quantiles = matrix(self$model$quantiles, nrow = n, ncol = length(self$model$quantiles), byrow = TRUE))
attr(prediction$quantiles, "probs") = self$model$quantile_probs
attr(prediction$quantiles, "response") = self$quantile_response
return(prediction)
}

prediction = setdiff(named_list(mlr_reflections$learner_predict_types[["regr"]][[self$predict_type]]), "quantiles")
missing_type = pv$predict_missing_type %??% "na"

for (pt in names(prediction)) {
Expand All @@ -91,7 +104,6 @@ LearnerRegrDebug = R6Class("LearnerRegrDebug", inherit = LearnerRegr,
prediction[[pt]] = value
}


return(prediction)
}
)
Expand Down
37 changes: 35 additions & 2 deletions R/PredictionDataRegr.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ check_prediction_data.PredictionDataRegr = function(pdata, ...) { # nolint
pdata$row_ids = assert_row_ids(pdata$row_ids)
n = length(pdata$row_ids)
if (is.null(pdata$truth)) pdata$truth = NA_real_
if (!length(pdata$row_ids)) pdata$truth = numeric(0)
if (!length(pdata$row_ids)) pdata$truth = numeric()

if (!is.null(pdata$response)) {
pdata$response = assert_numeric(unname(pdata$response))
Expand All @@ -16,6 +16,27 @@ check_prediction_data.PredictionDataRegr = function(pdata, ...) { # nolint
assert_prediction_count(length(pdata$se), n, "se")
}

if (!is.null(pdata$quantiles)) {
quantiles = pdata$quantiles
assert_matrix(quantiles)
assert_prediction_count(nrow(quantiles), n, "quantiles")

if (is.null(attr(quantiles, "probs"))) {
stopf("No probs attribute stored in 'quantile'")
}

if (is.null(attr(quantiles, "response"))) {
stopf("No response attribute stored in 'quantile'")
}

if (any(apply(quantiles, 1L, is.unsorted))) {
stopf("Quantiles are not ascending with probabilities")
}

colnames(pdata$quantiles) = sprintf("q%g", attr(quantiles, "probs"))
attr(pdata$quantiles, "response") = sprintf("q%g", attr(quantiles, "response"))
}

if (!is.null(pdata$distr)) {
assert_class(pdata$distr, "VectorDistribution")

Expand Down Expand Up @@ -45,6 +66,10 @@ is_missing_prediction_data.PredictionDataRegr = function(pdata, ...) { # nolint
miss = miss | is.na(pdata$se)
}

if (!is.null(pdata$quantiles)) {
miss = miss | apply(pdata$quantiles, 1L, anyMissing)
}

pdata$row_ids[miss]
}

Expand All @@ -67,12 +92,16 @@ c.PredictionDataRegr = function(..., keep_duplicates = TRUE) { # nolint

elems = c("row_ids", "truth", intersect(predict_types[[1L]], c("response", "se")))
tab = map_dtr(dots, function(x) x[elems], .fill = FALSE)
quantiles = do.call(rbind, map(dots, "quantiles"))

if (!keep_duplicates) {
tab = unique(tab, by = "row_ids", fromLast = TRUE)
keep = !duplicated(tab, by = "row_ids", fromLast = TRUE)
tab = tab[keep]
quantiles = quantiles[keep, , drop = FALSE]
}

result = as.list(tab)
result$quantiles = quantiles

if ("distr" %in% predict_types[[1L]]) {
require_namespaces("distr6", msg = "To predict probability distributions, please install %s")
Expand All @@ -96,5 +125,9 @@ filter_prediction_data.PredictionDataRegr = function(pdata, row_ids, ...) {
pdata$se = pdata$se[keep]
}

if (!is.null(pdata$quantiles)) {
pdata$quantiles = pdata$quantiles[keep, , drop = FALSE]
}

pdata
}
30 changes: 27 additions & 3 deletions R/PredictionRegr.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,19 @@ PredictionRegr = R6Class("PredictionRegr", inherit = Prediction,
#' Numeric vector of predicted standard errors.
#' One element for each observation in the test set.
#'
#' @param quantiles (`matrix()`)\cr
#' Numeric matrix of predicted quantiles. One row per observation, one column per quantile.
#'
#' @param distr (`VectorDistribution`)\cr
#' `VectorDistribution` from package distr6 (in repository \url{https://raphaels1.r-universe.dev}).
#' Each individual distribution in the vector represents the random variable 'survival time'
#' for an individual observation.
#'
#' @param check (`logical(1)`)\cr
#' If `TRUE`, performs some argument checks and predict type conversions.
initialize = function(task = NULL, row_ids = task$row_ids, truth = task$truth(), response = NULL, se = NULL, distr = NULL, check = TRUE) {
initialize = function(task = NULL, row_ids = task$row_ids, truth = task$truth(), response = NULL, se = NULL, quantiles = NULL, distr = NULL, check = TRUE) {
pdata = new_prediction_data(
list(row_ids = row_ids, truth = truth, response = response, se = se, distr = distr),
list(row_ids = row_ids, truth = truth, response = response, se = se, quantiles = quantiles, distr = distr),
task_type = "regr"
)

Expand All @@ -56,7 +59,11 @@ PredictionRegr = R6Class("PredictionRegr", inherit = Prediction,
self$task_type = "regr"
self$man = "mlr3::PredictionRegr"
self$data = pdata
self$predict_types = intersect(c("response", "se", "distr"), names(pdata))
predict_types = intersect(names(mlr_reflections$learner_predict_types[["regr"]]), names(pdata))
# response is in saved in quantiles matrix
if ("quantiles" %in% predict_types) predict_types = union(predict_types, "response")
self$predict_types = predict_types
private$.quantile_response = attr(quantiles, "response")
}
),

Expand All @@ -65,6 +72,7 @@ PredictionRegr = R6Class("PredictionRegr", inherit = Prediction,
#' Access the stored predicted response.
response = function(rhs) {
assert_ro_binding(rhs)
if (!is.null(private$.quantile_response)) return(self$data$quantiles[, private$.quantile_response])
self$data$response %??% rep(NA_real_, length(self$data$row_ids))
},

Expand All @@ -75,6 +83,13 @@ PredictionRegr = R6Class("PredictionRegr", inherit = Prediction,
self$data$se %??% rep(NA_real_, length(self$data$row_ids))
},

#' @field quantiles (`matrix()`)\cr
#' Matrix of predicted quantiles. Observations are in rows, quantile (in ascending order) in columns.
quantiles = function(rhs) {
assert_ro_binding(rhs)
self$data$quantiles
},

#' @field distr (`VectorDistribution`)\cr
#' Access the stored vector distribution.
#' Requires package `distr6`(in repository \url{https://raphaels1.r-universe.dev}) .
Expand All @@ -84,6 +99,10 @@ PredictionRegr = R6Class("PredictionRegr", inherit = Prediction,
}
return(self$data$distr)
}
),

private = list(
.quantile_response = NULL
)
)

Expand All @@ -92,6 +111,11 @@ PredictionRegr = R6Class("PredictionRegr", inherit = Prediction,
as.data.table.PredictionRegr = function(x, ...) { # nolint
tab = as.data.table(x$data[c("row_ids", "truth", "response", "se")])

if ("quantiles" %in% x$predict_types) {
tab = rcbind(tab, as.data.table(x$data$quantiles))
set(tab, j = "response", value = x$response)
}

if ("distr" %in% x$predict_types) {
require_namespaces("distr6", msg = "To predict probability distributions, please install %s")
tab$distr = list(x$distr)
Expand Down
2 changes: 1 addition & 1 deletion R/mlr_reflections.R
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ local({

mlr_reflections$learner_predict_types = list(
classif = list(response = "response", prob = c("response", "prob")),
regr = list(response = "response", se = c("response", "se"), distr = c("response", "se", "distr"))
regr = list(response = "response", se = c("response", "se"), quantiles = c("response", "quantiles"), distr = c("response", "se", "distr"))
)

# Allowed tags for parameters
Expand Down
2 changes: 1 addition & 1 deletion R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#' @importFrom R6 R6Class is.R6
#' @importFrom utils data head tail getFromNamespace packageVersion
#' @importFrom graphics plot
#' @importFrom stats predict rnorm runif sd contr.treatment model.frame terms
#' @importFrom stats predict rnorm runif sd contr.treatment model.frame terms quantile
#' @importFrom uuid UUIDgenerate
#' @importFrom parallelly availableCores
#' @importFrom future nbrOfWorkers plan
Expand Down
4 changes: 4 additions & 0 deletions inst/testthat/helper_autotest.R
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,10 @@ run_autotest = function(learner, N = 30L, exclude = NULL, predict_types = learne
learner$id = sprintf("%s:%s", id, predict_type)
learner$predict_type = predict_type

if (predict_type == "quantiles") {
learner$quantiles = 0.5
}

run = run_experiment(task, learner)
if (!run$ok) {
return(run)
Expand Down
14 changes: 14 additions & 0 deletions man/LearnerRegr.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions man/PredictionRegr.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/mlr_learners_regr.debug.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

37 changes: 37 additions & 0 deletions tests/testthat/test_Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,43 @@ test_that("validation task with 0 observations", {
expect_error({learner$train(task)}, "has 0 observations")
})

test_that("quantiles in LearnerRegr", {
task = tsk("mtcars")
learner = lrn("regr.debug", predict_type = "quantiles")
expect_learner(learner)
quantiles = c(0.05, 0.5, 0.95)
learner$quantiles = quantiles

expect_numeric(learner$quantiles, any.missing = FALSE, len = 3)

learner$quantile_response = 0.6
expect_equal(learner$quantile_response, 0.6)
expect_equal(learner$quantiles, c(0.05, 0.5, 0.6, 0.95))

expect_error({
learner$quantiles = c(0.5, 0.1)
}, "sorted")

expect_error({
learner$quantiles = integer()
}, "length")

learner$train(task)

expect_numeric(learner$model$quantiles, len = 4L)

pred = learner$predict(task)
expect_prediction(pred)
expect_subset("quantiles", pred$predict_types)
expect_matrix(pred$quantiles, ncols = 4L, nrows = task$nrow, any.missing = FALSE)
expect_true(!any(apply(pred$quantiles, 1L, is.unsorted)))
expect_equal(pred$response, pred$quantiles[, 3L])

tab = as.data.table(pred)
expect_data_table(tab, nrows = task$nrow)
expect_subset("q0.5", names(tab))
})

test_that("predict time is cumulative", {
learner = lrn("classif.debug", sleep_predict = function() 0.05)
task = tsk("iris")
Expand Down

0 comments on commit 9c6d1e3

Please sign in to comment.