Skip to content

Commit

Permalink
fix #452
Browse files Browse the repository at this point in the history
  • Loading branch information
mllg committed Mar 9, 2020
1 parent a591d70 commit 2670210
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 5 deletions.
17 changes: 13 additions & 4 deletions R/PredictionClassif.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@
#' Anyway, the class label with maximum ratio is selected.
#' In case of ties in the ratio, one of the tied class labels is selected randomly.
#'
#' Note that there are the following edge cases for threshold equal to `0` which are handled specially:
#' 1. With threshold 0 the resulting ratio gets `Inf` and thus gets always selected.
#' If there are multiple ratios with value `Inf`, one is selected according to `ties_method` (randomly per default).
#' 2. If additionally the predicted probability is also 0, the ratio `0/0` results in `NaN` values.
#' These are simply replaced by `0` and thus will never get selected.
#'
#' @family Prediction
#' @export
#' @examples
Expand Down Expand Up @@ -125,12 +131,15 @@ PredictionClassif = R6Class("PredictionClassif", inherit = Prediction,
#' See the section on thresholding for more information.
#'
#' @param threshold (`numeric()`).
#' @param ties_method (`character(1)`)\cr
#' One of `"random"`, `"first"` or `"last"` (c.f. [max.col()]) to determine how to deal with
#' tied probabilities.
#'
#' @return
#' Returns the object itself, but modified **by reference**.
#' You need to explicitly `$clone()` the object beforehand if you want to keeps
#' the object in its previous state.
set_threshold = function(threshold) {
set_threshold = function(threshold, ties_method = "random") {
if (!is.matrix(self$data$prob)) {
stopf("Cannot set threshold, no probabilities available")
}
Expand All @@ -148,11 +157,11 @@ PredictionClassif = R6Class("PredictionClassif", inherit = Prediction,
threshold = threshold[lvls] # reorder thresh so it is in the same order as levels

# multiply all rows by threshold, then get index of max element per row
w = ifelse(threshold > 0, 1 / threshold, Inf)
prob = self$data$prob %*% diag(w)
prob = self$data$prob %*% diag(1 / threshold) # can generate Inf for threshold 0
prob[is.na(prob)] = 0 # NaN results from 0 * Inf, replace with 0, c.f. #452
}

ind = max.col(prob, ties.method = "random")
ind = max.col(prob, ties.method = ties_method)
self$data$tab$response = factor(lvls[ind], levels = lvls)
invisible(self)
}
Expand Down
14 changes: 13 additions & 1 deletion man/PredictionClassif.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 26 additions & 0 deletions tests/testthat/test_PredictionClassif.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,32 @@ test_that("setting threshold multiclass", {
expect_equal(as.character(unique(x$response)), task$class_names[1L])
})

test_that("setting threshold edge cases (#452)", {
learner = lrn("classif.rpart", predict_type = "prob")
t = tsk("iris")
prd = learner$train(t)$predict(t$clone()$filter(c(1, 51, 101)))

prd$set_threshold(c(setosa = 0, versicolor = 0, virginica = 0), ties_method = "first")
expect_equal(as.character(prd$response), c("setosa", "versicolor", "versicolor"))
prd$set_threshold(c(setosa = 0, versicolor = 0, virginica = 0), ties_method = "last")
expect_equal(as.character(prd$response), c("setosa", "virginica", "virginica"))

prd$set_threshold(c(setosa = 1, versicolor = 0, virginica = 0), ties_method = "first")
expect_equal(as.character(prd$response), c("setosa", "versicolor", "versicolor"))
prd$set_threshold(c(setosa = 1, versicolor = 0, virginica = 0), ties_method = "last")
expect_equal(as.character(prd$response), c("setosa", "virginica", "virginica"))

prd$set_threshold(c(setosa = 0, versicolor = 1, virginica = 0), ties_method = "first")
expect_equal(as.character(prd$response), c("setosa", "virginica", "virginica"))
prd$set_threshold(c(setosa = 0, versicolor = 1, virginica = 0), ties_method = "last")
expect_equal(as.character(prd$response), c("setosa", "virginica", "virginica"))

prd$set_threshold(c(setosa = 0, versicolor = 0, virginica = 1), ties_method = "first")
expect_equal(as.character(prd$response), c("setosa", "versicolor", "versicolor"))
prd$set_threshold(c(setosa = 0, versicolor = 0, virginica = 1), ties_method = "last")
expect_equal(as.character(prd$response), c("setosa", "versicolor", "versicolor"))
})

test_that("confusion", {
task = tsk("iris")
lrn = lrn("classif.featureless")
Expand Down

0 comments on commit 2670210

Please sign in to comment.