Skip to content

Commit

Permalink
refactor: only pass quantiles
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Aug 19, 2024
1 parent 85db887 commit dafd087
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 16 deletions.
25 changes: 11 additions & 14 deletions R/LearnerRegrRanger.R
Original file line number Diff line number Diff line change
Expand Up @@ -126,27 +126,24 @@ LearnerRegrRanger = R6Class("LearnerRegrRanger",
pv = self$param_set$get_values(tags = "predict")
newdata = ordered_features(task, self)

type = switch(
self$predict_type,
response = "response",
se = "se",
quantile = "quantiles"
)

prediction = invoke(predict, self$model, data = newdata, type = type, quantiles = private$.quantile, .args = pv)
prediction = invoke(predict, self$model,
data = newdata,
type = if (self$predict_type == "quantile") "quantiles" else pv$type,
quantiles = private$.quantile,
.args = pv)

if (type == "quantiles") {
response = prediction$predictions[, which(private$.quantile == private$.quantile_response)]
if (self$predict_type == "quantile") {
quantile = prediction$predictions
attr(quantile, "probs") = private$.quantile
list(response = response, quantile = quantile)
} else {
list(response = prediction$predictions, se = prediction$se)
attr(quantile, "response") = private$.quantile_response
return(list(quantile = quantile))
}

list(response = prediction$predictions, se = prediction$se)
},

.hotstart = function(task) {
model = self$model
model = self$models
model$num.trees = self$param_set$values$num.trees
model
}
Expand Down
8 changes: 6 additions & 2 deletions tests/testthat/test_regr_ranger.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,12 @@ test_that("quantile prediction", {

learner$train(task)
pred = learner$predict(task)
expect_matrix(pred$quantile)

expect_matrix(pred$quantile, ncol = 3L)
expect_true(!any(apply(pred$quantile, 1L, is.unsorted)))
expect_equal(pred$response, pred$quantile[, 2L])

tab = as.data.table(pred)
expect_names(names(tab), must.include = c("q0.1", "q0.5", "q0.9", "response"))
expect_names(names(tab), identical.to = c("row_ids", "truth", "q0.1", "q0.5", "q0.9", "response"))
expect_equal(tab$response, tab$q0.5)
})

0 comments on commit dafd087

Please sign in to comment.