Refactor: avoid confusing predict_set arg in functions (#1085)
mllg authored Aug 18, 2024
1 parent 3f159d9 commit 1e6bbef
Showing 10 changed files with 91 additions and 59 deletions.
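In short, `$score()` on `ResampleResult` and `BenchmarkResult` no longer takes a `predict_sets` argument; prediction objects are opt-in via a logical `predictions` flag and come back as one list column per predict set. A hedged before/after sketch (`rr` stands in for any `ResampleResult`; the calls mirror the updated tests below):

```r
# before this commit: predictions were always attached, chosen via predict_sets
tab = rr$score(msr("classif.ce"), predict_sets = "test")
tab$prediction

# after: opt in with a flag; one list column per predict set of the learner
tab = rr$score(msr("classif.ce"), predictions = TRUE)
tab$prediction_test
```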
25 changes: 21 additions & 4 deletions R/BenchmarkResult.R
@@ -168,13 +168,17 @@ BenchmarkResult = R6Class("BenchmarkResult",
#' Adds condition messages (`"warnings"`, `"errors"`) as extra
#' list columns of character vectors to the returned table
#'
#' @template param_predict_sets
#' @param predictions (`logical(1)`)\cr
#' Additionally return prediction objects, one column for each `predict_set` of all learners combined.
#' Columns are named `"prediction_train"`, `"prediction_test"` and `"prediction_internal_valid"`,
#' if present.
#'
#' @return [data.table::data.table()].
score = function(measures = NULL, ids = TRUE, conditions = FALSE, predict_sets = "test") {
score = function(measures = NULL, ids = TRUE, conditions = FALSE, predictions = FALSE) {
measures = as_measures(measures, task_type = self$task_type)
assert_flag(ids)
assert_flag(conditions)
assert_flag(predictions)

tab = score_measures(self, measures, view = NULL)
tab = merge(private$.data$data$uhashes, tab, by = "uhash", sort = FALSE)
@@ -191,12 +195,25 @@ BenchmarkResult = R6Class("BenchmarkResult",
set(tab, j = "errors", value = map(tab$learner, "errors"))
}

set(tab, j = "prediction", value = as_predictions(tab$prediction, predict_sets))
if (predictions) {
predict_sets = intersect(
mlr_reflections$predict_sets,
unlist(map(self$learners$learner, "predict_sets"), use.names = FALSE)
)
predict_cols = sprintf("prediction_%s", predict_sets)
for (i in seq_along(predict_sets)) {
set(tab, j = predict_cols[i],
value = map(tab$prediction, function(p) as_prediction(p[[predict_sets[i]]], check = FALSE))
)
}
} else {
predict_cols = character()
}

set_data_table_class(tab, "bmr_score")

cns = c("uhash", "nr", "task", "task_id", "learner", "learner_id", "resampling", "resampling_id",
"iteration", "prediction", "warnings", "errors", ids(measures))
"iteration", predict_cols, "warnings", "errors", ids(measures))
cns = intersect(cns, names(tab))
tab[, cns, with = FALSE]
},
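
A usage sketch of the refactored `BenchmarkResult$score()` (a sketch, assuming a standard benchmark design; it mirrors the updated `test_benchmark.R` expectations below):

```r
library(mlr3)

design = benchmark_grid(tsks(c("iris", "sonar")), lrn("classif.rpart"), rsmp("cv", folds = 3))
bmr = benchmark(design)

# default: no prediction columns are materialized
tab = bmr$score(msr("classif.acc"))

# opt in: one list column per predict set of all learners combined
tab = bmr$score(msr("classif.acc"), predictions = TRUE)
tab$prediction_test  # list of Prediction objects
```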
38 changes: 26 additions & 12 deletions R/ResampleResult.R
@@ -111,18 +111,19 @@ ResampleResult = R6Class("ResampleResult",
#' `$prediction()`.
#'
#' @param predict_sets (`character()`)\cr
#' Subset of `{"train", "test"}`.
#' Subset of `{"train", "test", "internal_valid"}`.
#' @return List of [Prediction] objects, one per element in `predict_sets`.
predictions = function(predict_sets = "test") {
assert_subset(predict_sets, mlr_reflections$predict_sets, empty.ok = FALSE)
private$.data$predictions(private$.view, predict_sets)
},

#' @description
#' Returns a table with one row for each resampling iteration, including all involved objects:
#' [Task], [Learner], [Resampling], iteration number (`integer(1)`), and [Prediction].
#' [Task], [Learner], [Resampling], iteration number (`integer(1)`), and (if enabled)
#' one [Prediction] for each predict set of the [Learner].
#' Additionally, a column with the individual (per resampling iteration) performance is added
#' for each [Measure] in `measures`,
#' named with the id of the respective measure id.
#' for each [Measure] in `measures`, named with the id of the respective measure.
#' If `measures` is `NULL`, `measures` defaults to the return value of [default_measures()].
#'
#' @param ids (`logical(1)`)\cr
@@ -134,16 +134,17 @@ ResampleResult = R6Class("ResampleResult",
#' Adds condition messages (`"warnings"`, `"errors"`) as extra
#' list columns of character vectors to the returned table
#'
#' @param predict_sets (`character()`)\cr
#' Vector of predict sets (`{"train", "test"}`) to construct the [Prediction] objects from.
#' Default is `"test"`.
#' @param predictions (`logical(1)`)\cr
#' Additionally return prediction objects, one column for each `predict_set` of the learner.
#' Columns are named `"prediction_train"`, `"prediction_test"` and `"prediction_internal_valid"`,
#' if present.
#'
#' @return [data.table::data.table()].
score = function(measures = NULL, ids = TRUE, conditions = FALSE, predict_sets = "test") {
score = function(measures = NULL, ids = TRUE, conditions = FALSE, predictions = FALSE) {
measures = as_measures(measures, task_type = private$.data$task_type)
assert_flag(ids)
assert_flag(conditions)
assert_subset(predict_sets, mlr_reflections$predict_sets)
assert_flag(predictions)

tab = score_measures(self, measures, view = private$.view)

@@ -160,12 +162,22 @@ ResampleResult = R6Class("ResampleResult",
set(tab, j = "errors", value = map(tab$learner, "errors"))
}

set(tab, j = "prediction", value = as_predictions(tab$prediction, predict_sets))
if (predictions) {
predict_sets = intersect(mlr_reflections$predict_sets, tab$learner[[1L]]$predict_sets)
predict_cols = sprintf("prediction_%s", predict_sets)
for (i in seq_along(predict_sets)) {
set(tab, j = predict_cols[i],
value = map(tab$prediction, function(p) as_prediction(p[[predict_sets[i]]], check = FALSE))
)
}
} else {
predict_cols = character()
}

set_data_table_class(tab, "rr_score")

cns = c("task", "task_id", "learner", "learner_id", "resampling", "resampling_id", "iteration",
"prediction", "warnings", "errors", ids(measures))
predict_cols, "warnings", "errors", ids(measures))
cns = intersect(cns, names(tab))
tab[, cns, with = FALSE]
},
@@ -179,6 +191,7 @@ ResampleResult = R6Class("ResampleResult",
#' `NA` values.
#' Note that some measures, such as RMSE, do have an `$obs_loss`, but they require an
#' additional transformation after aggregation, in this example taking the square root.
#'
#' @param predict_sets (`character()`)\cr
#' The predict sets.
obs_loss = function(measures = NULL, predict_sets = "test") {
@@ -384,5 +397,6 @@ resample_result_aggregate = function(rr, measures) {

#' @export
print.rr_score = function(x, ...) {
print_data_table(x, c("task", "learner", "resampling", "prediction"))
predict_cols = sprintf("prediction_%s", mlr_reflections$predict_sets)
print_data_table(x, c("task", "learner", "resampling", predict_cols))
}
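
Note that `$predictions()` keeps its `predict_sets` argument (now also accepting `"internal_valid"`); only `$score()` drops it. A minimal sketch:

```r
library(mlr3)

rr = resample(tsk("iris"), lrn("classif.rpart"), rsmp("cv", folds = 3))

# one Prediction per resampling iteration, built from the requested sets
preds = rr$predictions(predict_sets = "test")
preds[[1L]]
```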
15 changes: 9 additions & 6 deletions R/as_prediction.R
@@ -42,18 +42,21 @@ as_predictions = function(x, predict_sets = "test", ...) {
#' @rdname as_prediction
#' @export
as_predictions.list = function(x, predict_sets = "test", ...) { # nolint
assert_subset(predict_sets, mlr_reflections$predict_sets)

result = vector("list", length(x))
ii = lengths(x) > 0L
result[ii] = map(x[ii], function(li) {
assert_list(li, "PredictionData")
combined = do.call(c, discard(li[predict_sets], is.null))
if (is.null(combined)) {
list()
li = discard(li[predict_sets], is.null)
if (length(li) == 0L) {
return(list())
}

if (length(li) == 1L) {
combined = li[[1L]]
} else {
as_prediction(combined, check = FALSE)
combined = do.call(c, li)
}
as_prediction(combined, check = FALSE)
})
result
}
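
The refactored helper now short-circuits instead of `c()`-combining a single predict set. A minimal sketch of the same control flow, using plain lists as hypothetical stand-ins for `PredictionData` objects (not the real mlr3 classes):

```r
combine_predict_sets = function(li, predict_sets = "test") {
  li = Filter(Negate(is.null), li[predict_sets])      # base-R equivalent of discard()
  if (length(li) == 0L) return(list())                # nothing to combine
  if (length(li) == 1L) li[[1L]] else do.call(c, li)  # combine only when needed
}

combine_predict_sets(list(train = NULL, test = 1:3), c("train", "test"))
#> [1] 1 2 3
```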
2 changes: 1 addition & 1 deletion R/mlr_reflections.R
@@ -157,7 +157,7 @@ local({
)

### ResampleResult
mlr_reflections$rr_names = c("task", "learner", "resampling", "iteration", "prediction")
mlr_reflections$rr_names = c("task", "learner", "resampling", "iteration")

### Logger
mlr_reflections$loggers = list()
4 changes: 2 additions & 2 deletions inst/testthat/helper_expectations.R
@@ -623,7 +623,7 @@ expect_benchmark_result = function(bmr) {
expect_resultdata(mlr3misc::get_private(bmr)$.data, TRUE)
testthat::expect_output(print(bmr), "BenchmarkResult")

checkmate::expect_names(names(as.data.table(bmr)), permutation.of = c(mlr3::mlr_reflections$rr_names, "uhash"))
checkmate::expect_names(names(as.data.table(bmr)), permutation.of = c(mlr3::mlr_reflections$rr_names, "prediction", "uhash"))

tab = bmr$tasks
checkmate::expect_data_table(tab, ncols = 3L)
@@ -679,7 +679,7 @@ expect_benchmark_result = function(bmr) {
checkmate::expect_data_table(tab, ncols = 3L, nrows = bmr$n_resample_results, any.missing = FALSE)
checkmate::expect_character(tab$uhash, any.missing = FALSE)
checkmate::expect_integer(tab$nr, sorted = TRUE, any.missing = FALSE, lower = 1L)
# expect_integer(tab$iters, any.missing = FALSE, lower = 1L)
expect_integer(tab$nr, any.missing = FALSE, lower = 1L)
checkmate::expect_list(tab$resample_result, types = "ResampleResult")

ni = mlr3misc::get_private(bmr)$.data$iterations()
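
For contrast, the `as.data.table()` converters still expose a single `prediction` list column, which is why `"prediction"` leaves `rr_names` above but is added back explicitly in this expectation. A sketch (column order may differ):

```r
library(mlr3)

bmr = benchmark(benchmark_grid(tsk("iris"), lrn("classif.rpart"), rsmp("cv", folds = 3)))
names(as.data.table(bmr))
#> "uhash" "task" "learner" "resampling" "iteration" "prediction"
```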
12 changes: 5 additions & 7 deletions man/BenchmarkResult.Rd

Some generated files are not rendered by default.

17 changes: 9 additions & 8 deletions man/ResampleResult.Rd

Some generated files are not rendered by default.

10 changes: 5 additions & 5 deletions tests/testthat/test_benchmark.R
@@ -7,17 +7,17 @@ bmr = benchmark(design)

test_that("Basic benchmarking", {
expect_benchmark_result(bmr)
expect_names(names(as.data.table(bmr)), permutation.of = c(mlr_reflections$rr_names, "uhash"))
expect_names(names(as.data.table(bmr)), permutation.of = c(mlr_reflections$rr_names, "uhash", "prediction"))

tab = as.data.table(bmr)
expect_data_table(tab, nrows = 18L, ncols = 6L)
expect_names(names(tab), permutation.of = c("uhash", mlr_reflections$rr_names))
expect_names(names(tab), permutation.of = c("uhash", "prediction", mlr_reflections$rr_names))
measures = list(msr("classif.acc"))

tab = bmr$score(measures, ids = FALSE)
tab = bmr$score(measures, ids = FALSE, predictions = TRUE)
expect_data_table(tab, nrows = 18L, ncols = 7L + length(measures))
expect_names(names(tab), must.include = c("nr", "uhash", mlr_reflections$rr_names, ids(measures)))
expect_list(tab$prediction, "Prediction")
expect_names(names(tab), must.include = c("nr", "uhash", "prediction_test", mlr_reflections$rr_names, ids(measures)))
expect_list(tab$prediction_test, "Prediction")

tab = bmr$tasks
expect_data_table(tab, nrows = 3L, any.missing = FALSE)
9 changes: 4 additions & 5 deletions tests/testthat/test_mlr_reflections.R
@@ -59,13 +59,13 @@ test_that("resampling works", {
rr = resample(task, learner, rsmp("cv", folds = 3))
expect_equal(rr$task_type, "test")

scores = rr$score(msr("classif.ce"))
expect_list(scores$prediction, "Prediction")
scores = rr$score(msr("classif.ce"), predictions = TRUE)
expect_list(scores$prediction_test, "Prediction")
expect_numeric(scores$classif.ce, any.missing = FALSE)
expect_number(rr$aggregate(msr("classif.ce")))

scores = rr$score()
expect_list(scores$prediction, "Prediction")
scores = rr$score(predictions = TRUE)
expect_list(scores$prediction_test, "Prediction")
expect_numeric(scores$classif.ce, any.missing = FALSE)
expect_number(rr$aggregate(msr("classif.ce")))
})
@@ -104,4 +104,3 @@ test_that("external packages can set column roles", {
resample(task, lrn("classif.rpart"), rsmp("cv", folds = 3))
})
})

18 changes: 9 additions & 9 deletions tests/testthat/test_resample.R
@@ -6,8 +6,8 @@ rr = resample(task, learner, resampling)
test_that("resample", {
expect_resample_result(rr)

scores = rr$score(msr("classif.ce"))
expect_list(scores$prediction, "Prediction")
scores = rr$score(msr("classif.ce"), predictions = TRUE)
expect_list(scores$prediction_test, "Prediction")
expect_numeric(scores$classif.ce, any.missing = FALSE)
expect_number(rr$aggregate(msr("classif.ce")))
learners = rr$learners
@@ -29,9 +29,9 @@ test_that("empty RR", {

test_that("resample with no or multiple measures", {
for (measures in list(mlr_measures$mget(c("classif.ce", "classif.acc")), list())) {
tab = rr$score(measures, ids = FALSE)
expect_data_table(tab, ncols = length(mlr_reflections$rr_names) + length(measures), nrows = 3L)
expect_set_equal(names(tab), c(mlr_reflections$rr_names, ids(measures)))
tab = rr$score(measures, ids = FALSE, predictions = TRUE)
expect_data_table(tab, ncols = length(mlr_reflections$rr_names) + length(learner$predict_sets) + length(measures), nrows = 3L)
expect_set_equal(names(tab), c(mlr_reflections$rr_names, ids(measures), paste0("prediction_", learner$predict_sets)))
perf = rr$aggregate(measures)
expect_numeric(perf, any.missing = FALSE, len = length(measures), names = "unique")
expect_equal(names(perf), unname(ids(measures)))
@@ -302,7 +302,7 @@ test_that("internal_valid and train predictions", {
measure_valid = msr("classif.acc")
measure_valid$predict_sets = "internal_valid"
expect_equal(
rr$score(measure_valid, predict_sets = "internal_valid")$classif.acc,
rr$score(measure_valid)$classif.acc,
rr$learners[[1L]]$internal_valid_scores$acc
)

@@ -312,9 +312,9 @@ test_that("internal_valid and train predictions", {
rr2 = resample(task, learner, rsmp("holdout"))

expect_equal(
rr2$score(measure_valid, predict_sets = "internal_valid")$classif.acc,
rr2$score(msr("classif.acc"), predict_sets = "test")$classif.acc
)
rr2$score(measure_valid)$classif.acc,
rr2$score(msr("classif.acc"))$classif.acc
)
expect_equal(
rr2$predictions("internal_valid")[[1L]]$response,
rr2$predictions("test")[[1L]]$response
