Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow quantile predictions for regression #1086

Merged
merged 13 commits into from
Aug 20, 2024
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ importFrom(parallelly,availableCores)
importFrom(stats,contr.treatment)
importFrom(stats,model.frame)
importFrom(stats,predict)
importFrom(stats,quantile)
importFrom(stats,rnorm)
importFrom(stats,runif)
importFrom(stats,sd)
Expand Down
44 changes: 44 additions & 0 deletions R/LearnerRegr.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,49 @@ LearnerRegr = R6Class("LearnerRegr", inherit = Learner,
predict_types = predict_types, properties = properties, data_formats = data_formats, packages = packages,
label = label, man = man)
}
),

active = list(

#' @field quantiles (`numeric()`)\cr
#' Numeric vector of probabilities to be used while predicting quantiles.
#' Elements must be between 0 and 1, not missing and provided in ascending order.
#' If only one quantile is provided, it is used as response.
#' Otherwise, set `$quantile_response` to specify the response quantile.
quantiles = function(rhs) {
  # Read access: nothing was assigned, so hand back the stored probabilities.
  if (missing(rhs)) {
    return(private$.quantiles)
  }

  # Write access is only valid for learners that list "quantiles" among their predict types.
  if (!("quantiles" %in% self$predict_types)) {
    stopf("Learner does not support predicting quantiles")
  }

  # Probabilities must lie in [0, 1], contain no NAs, have at least one element and be ascending.
  probs = assert_numeric(rhs, lower = 0, upper = 1, any.missing = FALSE, min.len = 1L, sorted = TRUE, .var.name = "quantiles")
  private$.quantiles = probs

  # A single quantile also serves as the response quantile.
  if (length(probs) == 1) {
    private$.quantile_response = probs
  }
},

#' @field quantile_response (`numeric(1)`)\cr
#' The quantile to be used as response.
quantile_response = function(rhs) {
  # Read access: return the currently selected response quantile.
  if (missing(rhs)) {
    return(private$.quantile_response)
  }

  # Write access is only valid for learners that list "quantiles" among their predict types.
  if (!("quantiles" %in% self$predict_types)) {
    stopf("Learner does not support predicting quantiles")
  }

  # Must be a single probability in [0, 1].
  response_prob = assert_number(rhs, lower = 0, upper = 1, .var.name = "response")
  private$.quantile_response = response_prob

  # Merge the response quantile into the stored quantiles, keeping them ascending.
  private$.quantiles = sort(union(private$.quantiles, response_prob))
}
),


private = list(
# Backing store for the `quantiles` active binding: the probabilities to predict.
# Stays NULL until `quantiles` or `quantile_response` is assigned.
.quantiles = NULL,
# Backing store for the `quantile_response` active binding: the probability used as response.
# Set by `quantile_response`, or by `quantiles` when exactly one quantile is given.
.quantile_response = NULL
)
)
18 changes: 15 additions & 3 deletions R/LearnerRegrDebug.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ LearnerRegrDebug = R6Class("LearnerRegrDebug", inherit = LearnerRegr,
super$initialize(
id = "regr.debug",
feature_types = c("logical", "integer", "numeric", "character", "factor", "ordered"),
predict_types = c("response", "se"),
predict_types = c("response", "se", "quantiles"),
param_set = ps(
predict_missing = p_dbl(0, 1, default = 0, tags = "predict"),
predict_missing_type = p_fct(c("na", "omit"), default = "na", tags = "predict"),
Expand All @@ -61,6 +61,12 @@ LearnerRegrDebug = R6Class("LearnerRegrDebug", inherit = LearnerRegr,
pid = Sys.getpid()
)

if (self$predict_type == "quantiles") {
probs = self$quantiles
model$quantiles = unname(quantile(truth, probs))
model$quantile_probs = probs
}

if (isTRUE(pv$save_tasks)) {
model$task_train = task$clone(deep = TRUE)
}
Expand All @@ -75,7 +81,14 @@ LearnerRegrDebug = R6Class("LearnerRegrDebug", inherit = LearnerRegr,
self$state$model$task_predict = task$clone(deep = TRUE)
}

prediction = named_list(mlr_reflections$learner_predict_types[["regr"]][[self$predict_type]])
if (self$predict_type == "quantiles") {
prediction = list(quantiles = matrix(self$model$quantiles, nrow = n, ncol = length(self$model$quantiles), byrow = TRUE))
attr(prediction$quantiles, "probs") = self$model$quantile_probs
attr(prediction$quantiles, "response") = self$quantile_response
return(prediction)
}

prediction = setdiff(named_list(mlr_reflections$learner_predict_types[["regr"]][[self$predict_type]]), "quantiles")
missing_type = pv$predict_missing_type %??% "na"

for (pt in names(prediction)) {
Expand All @@ -91,7 +104,6 @@ LearnerRegrDebug = R6Class("LearnerRegrDebug", inherit = LearnerRegr,
prediction[[pt]] = value
}


return(prediction)
}
)
Expand Down
37 changes: 35 additions & 2 deletions R/PredictionDataRegr.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ check_prediction_data.PredictionDataRegr = function(pdata, ...) { # nolint
pdata$row_ids = assert_row_ids(pdata$row_ids)
n = length(pdata$row_ids)
if (is.null(pdata$truth)) pdata$truth = NA_real_
if (!length(pdata$row_ids)) pdata$truth = numeric(0)
if (!length(pdata$row_ids)) pdata$truth = numeric()

if (!is.null(pdata$response)) {
pdata$response = assert_numeric(unname(pdata$response))
Expand All @@ -16,6 +16,27 @@ check_prediction_data.PredictionDataRegr = function(pdata, ...) { # nolint
assert_prediction_count(length(pdata$se), n, "se")
}

if (!is.null(pdata$quantiles)) {
quantiles = pdata$quantiles
assert_matrix(quantiles)
assert_prediction_count(nrow(quantiles), n, "quantiles")

if (is.null(attr(quantiles, "probs"))) {
stopf("No probs attribute stored in 'quantile'")
}

if (is.null(attr(quantiles, "response"))) {
stopf("No response attribute stored in 'quantile'")
}

if (any(apply(quantiles, 1L, is.unsorted))) {
stopf("Quantiles are not ascending with probabilities")
}

colnames(pdata$quantiles) = sprintf("q%g", attr(quantiles, "probs"))
attr(pdata$quantiles, "response") = sprintf("q%g", attr(quantiles, "response"))
}

if (!is.null(pdata$distr)) {
assert_class(pdata$distr, "VectorDistribution")

Expand Down Expand Up @@ -45,6 +66,10 @@ is_missing_prediction_data.PredictionDataRegr = function(pdata, ...) { # nolint
miss = miss | is.na(pdata$se)
}

if (!is.null(pdata$quantiles)) {
miss = miss | apply(pdata$quantiles, 1L, anyMissing)
}

pdata$row_ids[miss]
}

Expand All @@ -67,12 +92,16 @@ c.PredictionDataRegr = function(..., keep_duplicates = TRUE) { # nolint

elems = c("row_ids", "truth", intersect(predict_types[[1L]], c("response", "se")))
tab = map_dtr(dots, function(x) x[elems], .fill = FALSE)
quantiles = do.call(rbind, map(dots, "quantiles"))

if (!keep_duplicates) {
tab = unique(tab, by = "row_ids", fromLast = TRUE)
keep = !duplicated(tab, by = "row_ids", fromLast = TRUE)
tab = tab[keep]
quantiles = quantiles[keep, , drop = FALSE]
}

result = as.list(tab)
result$quantiles = quantiles

if ("distr" %in% predict_types[[1L]]) {
require_namespaces("distr6", msg = "To predict probability distributions, please install %s")
Expand All @@ -96,5 +125,9 @@ filter_prediction_data.PredictionDataRegr = function(pdata, row_ids, ...) {
pdata$se = pdata$se[keep]
}

if (!is.null(pdata$quantiles)) {
pdata$quantiles = pdata$quantiles[keep, , drop = FALSE]
}

pdata
}
30 changes: 27 additions & 3 deletions R/PredictionRegr.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,19 @@ PredictionRegr = R6Class("PredictionRegr", inherit = Prediction,
#' Numeric vector of predicted standard errors.
#' One element for each observation in the test set.
#'
#' @param quantiles (`matrix()`)\cr
#' Numeric matrix of predicted quantiles. One row per observation, one column per quantile.
#'
#' @param distr (`VectorDistribution`)\cr
#' `VectorDistribution` from package distr6 (in repository \url{https://raphaels1.r-universe.dev}).
#' Each individual distribution in the vector represents the random variable 'survival time'
#' for an individual observation.
#'
#' @param check (`logical(1)`)\cr
#' If `TRUE`, performs some argument checks and predict type conversions.
initialize = function(task = NULL, row_ids = task$row_ids, truth = task$truth(), response = NULL, se = NULL, distr = NULL, check = TRUE) {
initialize = function(task = NULL, row_ids = task$row_ids, truth = task$truth(), response = NULL, se = NULL, quantiles = NULL, distr = NULL, check = TRUE) {
pdata = new_prediction_data(
list(row_ids = row_ids, truth = truth, response = response, se = se, distr = distr),
list(row_ids = row_ids, truth = truth, response = response, se = se, quantiles = quantiles, distr = distr),
task_type = "regr"
)

Expand All @@ -56,7 +59,11 @@ PredictionRegr = R6Class("PredictionRegr", inherit = Prediction,
self$task_type = "regr"
self$man = "mlr3::PredictionRegr"
self$data = pdata
self$predict_types = intersect(c("response", "se", "distr"), names(pdata))
predict_types = intersect(names(mlr_reflections$learner_predict_types[["regr"]]), names(pdata))
# response is saved in the quantiles matrix
if ("quantiles" %in% predict_types) predict_types = union(predict_types, "response")
self$predict_types = predict_types
private$.quantile_response = attr(quantiles, "response")
}
),

Expand All @@ -65,6 +72,7 @@ PredictionRegr = R6Class("PredictionRegr", inherit = Prediction,
#' Access the stored predicted response.
response = function(rhs) {
  # Read-only binding: reject any assignment.
  assert_ro_binding(rhs)

  response_col = private$.quantile_response
  if (is.null(response_col)) {
    # No response quantile stored: use the stored response, or NA for every row if there is none.
    return(self$data$response %??% rep(NA_real_, length(self$data$row_ids)))
  }

  # The response is the column of the quantiles matrix selected by the stored response quantile.
  self$data$quantiles[, response_col]
},

Expand All @@ -75,6 +83,13 @@ PredictionRegr = R6Class("PredictionRegr", inherit = Prediction,
self$data$se %??% rep(NA_real_, length(self$data$row_ids))
},

#' @field quantiles (`matrix()`)\cr
#' Matrix of predicted quantiles. Observations are in rows, quantiles (in ascending order) in columns.
quantiles = function(rhs) {
  # Read-only binding: reject any assignment.
  assert_ro_binding(rhs)

  # NULL when the prediction data holds no quantiles.
  quantile_matrix = self$data$quantiles
  quantile_matrix
},

#' @field distr (`VectorDistribution`)\cr
#' Access the stored vector distribution.
#' Requires package `distr6`(in repository \url{https://raphaels1.r-universe.dev}) .
Expand All @@ -84,6 +99,10 @@ PredictionRegr = R6Class("PredictionRegr", inherit = Prediction,
}
return(self$data$distr)
}
),

private = list(
# Selector for the response column of the quantiles matrix.
# Filled in `initialize()` from the "response" attribute of the `quantiles` argument
# and read by the `response` active binding; NULL when that attribute is absent.
.quantile_response = NULL
)
)

Expand All @@ -92,6 +111,11 @@ PredictionRegr = R6Class("PredictionRegr", inherit = Prediction,
as.data.table.PredictionRegr = function(x, ...) { # nolint
tab = as.data.table(x$data[c("row_ids", "truth", "response", "se")])

if ("quantiles" %in% x$predict_types) {
tab = rcbind(tab, as.data.table(x$data$quantiles))
set(tab, j = "response", value = x$response)
}

if ("distr" %in% x$predict_types) {
require_namespaces("distr6", msg = "To predict probability distributions, please install %s")
tab$distr = list(x$distr)
Expand Down
2 changes: 1 addition & 1 deletion R/mlr_reflections.R
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ local({

mlr_reflections$learner_predict_types = list(
classif = list(response = "response", prob = c("response", "prob")),
regr = list(response = "response", se = c("response", "se"), distr = c("response", "se", "distr"))
regr = list(response = "response", se = c("response", "se"), quantiles = c("response", "quantiles"), distr = c("response", "se", "distr"))
)

# Allowed tags for parameters
Expand Down
2 changes: 1 addition & 1 deletion R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#' @importFrom R6 R6Class is.R6
#' @importFrom utils data head tail getFromNamespace packageVersion
#' @importFrom graphics plot
#' @importFrom stats predict rnorm runif sd contr.treatment model.frame terms
#' @importFrom stats predict rnorm runif sd contr.treatment model.frame terms quantile
#' @importFrom uuid UUIDgenerate
#' @importFrom parallelly availableCores
#' @importFrom future nbrOfWorkers plan
Expand Down
4 changes: 4 additions & 0 deletions inst/testthat/helper_autotest.R
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,10 @@ run_autotest = function(learner, N = 30L, exclude = NULL, predict_types = learne
learner$id = sprintf("%s:%s", id, predict_type)
learner$predict_type = predict_type

if (predict_type == "quantiles") {
learner$quantiles = 0.5
}

run = run_experiment(task, learner)
if (!run$ok) {
return(run)
Expand Down
14 changes: 14 additions & 0 deletions man/LearnerRegr.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions man/PredictionRegr.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/mlr_learners_regr.debug.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

37 changes: 37 additions & 0 deletions tests/testthat/test_Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,43 @@ test_that("validation task with 0 observations", {
expect_error({learner$train(task)}, "has 0 observations")
})

# Exercises quantile support end to end with the debug regression learner:
# the setters on the learner, training, prediction and conversion to a data.table.
test_that("quantiles in LearnerRegr", {
task = tsk("mtcars")
learner = lrn("regr.debug", predict_type = "quantiles")
expect_learner(learner)
quantiles = c(0.05, 0.5, 0.95)
learner$quantiles = quantiles

expect_numeric(learner$quantiles, any.missing = FALSE, len = 3)

# Setting a response quantile that is not yet in `quantiles` merges it in, keeping ascending order.
learner$quantile_response = 0.6
expect_equal(learner$quantile_response, 0.6)
expect_equal(learner$quantiles, c(0.05, 0.5, 0.6, 0.95))

# Unsorted probabilities are rejected.
expect_error({
learner$quantiles = c(0.5, 0.1)
}, "sorted")

# An empty vector is rejected.
expect_error({
learner$quantiles = integer()
}, "length")

# The failed assignments above must leave the four quantiles in place for training.
learner$train(task)

expect_numeric(learner$model$quantiles, len = 4L)

pred = learner$predict(task)
expect_prediction(pred)
expect_subset("quantiles", pred$predict_types)
expect_matrix(pred$quantiles, ncols = 4L, nrows = task$nrow, any.missing = FALSE)
# Each row must be non-decreasing across the quantile columns.
expect_true(!any(apply(pred$quantiles, 1L, is.unsorted)))
# Column 3 corresponds to the response quantile 0.6 set above.
expect_equal(pred$response, pred$quantiles[, 3L])

# Quantile columns appear in the table under "q<prob>" names.
tab = as.data.table(pred)
expect_data_table(tab, nrows = task$nrow)
expect_subset("q0.5", names(tab))
})

test_that("predict time is cumulative", {
learner = lrn("classif.debug", sleep_predict = function() 0.05)
task = tsk("iris")
Expand Down