diff --git a/R/adjust-numeric-calibration.R b/R/adjust-numeric-calibration.R
index 10a9ec5..19c3176 100644
--- a/R/adjust-numeric-calibration.R
+++ b/R/adjust-numeric-calibration.R
@@ -1,8 +1,11 @@
 #' Re-calibrate numeric predictions
 #'
 #' @param x A [container()].
-#' @param calibrator A pre-trained calibration method from the \pkg{probably}
-#' package, such as [probably::cal_estimate_linear()].
+#' @param type Character. One of `"linear"`, `"isotonic"`, or
+#' `"isotonic_boot"`, corresponding to the function from the \pkg{probably}
+#' package [probably::cal_estimate_linear()],
+#' [probably::cal_estimate_isotonic()], or
+#' [probably::cal_estimate_isotonic_boot()], respectively.
 #' @examples
 #' library(modeldata)
 #' library(probably)
@@ -14,27 +17,24 @@
 #'
 #' dat
 #'
-#' # calibrate numeric predictions
-#' reg_cal <- cal_estimate_linear(dat, truth = y, estimate = y_pred)
-#'
 #' # specify calibration
 #' reg_ctr <-
 #'   container(mode = "regression") %>%
-#'   adjust_numeric_calibration(reg_cal)
+#'   adjust_numeric_calibration(type = "linear")
 #'
-#' # "train" container
+#' # train container
 #' reg_ctr_trained <- fit(reg_ctr, dat, outcome = y, estimate = y_pred)
 #'
-#' predict(reg_ctr, dat)
+#' predict(reg_ctr_trained, dat)
 #' @export
-adjust_numeric_calibration <- function(x, calibrator) {
-  check_container(x)
-  check_required(calibrator)
-  if (!inherits(calibrator, "cal_regression")) {
-    cli_abort(
-      "{.arg calibrator} should be a \\
-      {.help [ object](probably::cal_estimate_linear)}, \\
-      not {.obj_type_friendly {calibrator}}."
+adjust_numeric_calibration <- function(x, type = NULL) {
+  # to-do: add argument specifying `prop` in initial_split
+  check_container(x, calibration_type = "numeric")
+  # wait to `check_type()` until `fit()` time
+  if (!is.null(type)) {
+    arg_match0(
+      type,
+      c("linear", "isotonic", "isotonic_boot")
     )
   }
 
@@ -43,7 +43,7 @@ adjust_numeric_calibration <- function(x, calibrator) {
     "numeric_calibration",
     inputs = "numeric",
     outputs = "numeric",
-    arguments = list(calibrator = calibrator),
+    arguments = list(type = type),
     results = list(),
     trained = FALSE
   )
@@ -67,19 +67,37 @@ print.numeric_calibration <- function(x, ...) {
 
 #' @export
 fit.numeric_calibration <- function(object, data, container = NULL, ...) {
+  # `type` is stored in `object$arguments` by `adjust_numeric_calibration()`;
+  # `check_type()` validates it against the container type, or infers a
+  # default when it is `NULL`
+  type <- check_type(object$arguments$type, container$type)
+  # todo: adjust_numeric_calibration() should take arguments to pass to
+  # cal_estimate_* via dots
+  # build and evaluate the call `probably::cal_estimate_<type>()`
+  fit <-
+    eval_bare(
+      call2(
+        paste0("cal_estimate_", type),
+        .data = data,
+        truth = container$columns$outcome,
+        estimate = container$columns$estimate,
+        .ns = "probably"
+      )
+    )
+
   new_operation(
     class(object),
     inputs = object$inputs,
     outputs = object$outputs,
     arguments = object$arguments,
-    results = list(),
+    results = list(fit = fit),
     trained = TRUE
   )
 }
 
 #' @export
 predict.numeric_calibration <- function(object, new_data, container, ...) {
-  probably::cal_apply(new_data, object$argument$calibrator)
+  probably::cal_apply(new_data, object$results$fit)
 }
 
 # todo probably needs required_pkgs methods for cal objects
diff --git a/R/adjust-probability-calibration.R b/R/adjust-probability-calibration.R
index 9808285..6206d53 100644
--- a/R/adjust-probability-calibration.R
+++ b/R/adjust-probability-calibration.R
@@ -1,18 +1,19 @@
 #' Re-calibrate classification probability predictions
 #'
 #' @param x A [container()].
-#' @param calibrator A pre-trained calibration method from the \pkg{probably}
-#' package, such as [probably::cal_estimate_logistic()].
+#' @param type Character.
One of `"logistic"`, `"multinomial"`,
+#' `"beta"`, `"isotonic"`, or `"isotonic_boot"`, corresponding to the
+#' function from the \pkg{probably} package [probably::cal_estimate_logistic()],
+#' [probably::cal_estimate_multinomial()], etc., respectively.
 #' @export
-adjust_probability_calibration <- function(x, calibrator) {
-  check_container(x)
-  cls <- c("cal_binary", "cal_multinomial")
-  check_required(calibrator)
-  if (!inherits_any(calibrator, cls)) {
-    cli_abort(
-      "{.arg calibrator} should be a \\
-      {.help [ or object](probably::cal_estimate_logistic)}, \\
-      not {.obj_type_friendly {calibrator}}."
+adjust_probability_calibration <- function(x, type = NULL) {
+  # to-do: add argument specifying `prop` in initial_split
+  check_container(x, calibration_type = "probability")
+  # wait to `check_type()` until `fit()` time
+  if (!is.null(type)) {
+    arg_match(
+      type,
+      c("logistic", "multinomial", "beta", "isotonic", "isotonic_boot")
     )
   }
 
@@ -21,7 +22,7 @@ adjust_probability_calibration <- function(x, calibrator) {
     "probability_calibration",
     inputs = "probability",
     outputs = "probability_class",
-    arguments = list(calibrator = calibrator),
+    arguments = list(type = type),
     results = list(),
     trained = FALSE
   )
@@ -45,19 +46,39 @@ print.probability_calibration <- function(x, ...) {
 
 #' @export
 fit.probability_calibration <- function(object, data, container = NULL, ...) {
+  # `type` is stored in `object$arguments` by `adjust_probability_calibration()`;
+  # `check_type()` validates it against the container type, or infers a
+  # default when it is `NULL`
+  type <- check_type(object$arguments$type, container$type)
+  # todo: adjust_probability_calibration() should take arguments to pass to
+  # cal_estimate_* via dots
+  # to-do: add argument specifying `prop` in initial_split
+  # build and evaluate the call `probably::cal_estimate_<type>()`
+  fit <-
+    eval_bare(
+      call2(
+        paste0("cal_estimate_", type),
+        .data = data,
+        # todo: make getters for the entries in `columns`
+        truth = container$columns$outcome,
+        estimate = container$columns$estimate,
+        .ns = "probably"
+      )
+    )
+
   new_operation(
     class(object),
     inputs = object$inputs,
     outputs = object$outputs,
     arguments = object$arguments,
-    results = list(),
+    results = list(fit = fit),
     trained = TRUE
   )
 }
 
 #' @export
 predict.probability_calibration <- function(object, new_data, container, ...) {
-  probably::cal_apply(new_data, object$argument$calibrator)
+  probably::cal_apply(new_data, object$results$fit)
 }
 
 # todo probably needs required_pkgs methods for cal objects
diff --git a/R/container.R b/R/container.R
index e199a7d..a07e852 100644
--- a/R/container.R
+++ b/R/container.R
@@ -130,7 +130,7 @@ fit.container <- function(object, .data, outcome, estimate, probabilities = c(),
   num_oper <- length(object$operations)
 
   for (op in seq_len(num_oper)) {
-    object$operations[[op]] <- fit(object$operations[[op]], data, object)
+    object$operations[[op]] <- fit(object$operations[[op]], .data, object)
     .data <- predict(object$operations[[op]], .data, object)
   }
 
diff --git a/R/utils.R b/R/utils.R
index e90ed46..b6bcef7 100644
--- a/R/utils.R
+++ b/R/utils.R
@@ -49,14 +49,115 @@ is_container <- function(x) {
 }
 
 # ad-hoc checking --------------------------------------------------------------
-check_container <- function(x, call = caller_env(), arg = caller_arg(x)) {
+# errors unless `x` is a container; when `calibration_type` is supplied
+# ("numeric" or "probability"), also errors if that kind of calibration is
+# incompatible with the container's `type`.
+check_container <- function(x, calibration_type = NULL, call = caller_env(), arg = caller_arg(x)) {
   if (!is_container(x)) {
-    cli::cli_abort(
+    cli_abort(
       "{.arg {arg}} should be a {.help [{.cls container}](container::container)}, \\
       not {.obj_type_friendly {x}}.",
       call = call
     )
   }
 
+  # check that the type of calibration ("numeric" or "probability") is
+  # compatible with the container type
+  if (!is.null(calibration_type)) {
+    container_type <- x$type
+    switch(
+      container_type,
+      regression =
+        check_calibration_type(calibration_type, "numeric", container_type, call = call),
+      # `check_type()` spells the multi-class container type "multiclass";
+      # "multinomial" is still matched so existing behaviour is unchanged
+      binary = , multiclass = , multinomial =
+        check_calibration_type(calibration_type, "probability", container_type, call = call)
+    )
+  }
+
   invisible()
 }
+
+# errors when the requested kind of calibration (`calibration_type`) is not the
+# one expected (`calibration_type_expected`) for a container of `container_type`.
+check_calibration_type <- function(calibration_type, calibration_type_expected,
+                                   container_type, call) {
+  if (!identical(calibration_type, calibration_type_expected)) {
+    cli_abort(
+      "A {.field {container_type}} container is incompatible with the operation \\
+      {.fun {paste0('adjust_', calibration_type, '_calibration')}}.",
+      call = call
+    )
+  }
+}
+
+types_regression <- c("linear", "isotonic", "isotonic_boot")
+types_binary <- c("logistic", "beta", "isotonic", "isotonic_boot")
+types_multiclass <- c("multinomial", "beta", "isotonic", "isotonic_boot")
+# a check function to be called when a container is being `fit()`ted.
+# by the time a container is fitted, we have:
+# * `adjust_type`, the `type` argument passed to an `adjust_*` function
+#   * this argument has already been checked to agree with the kind of
+#     `adjust_*()` function via `arg_match0()`.
+# * `container_type`, the `type` argument either specified in `container()`
+#   or inferred in `fit.container()`.
+# returns the calibration type to use, as a string.
+check_type <- function(adjust_type,
+                       container_type,
+                       arg = caller_arg(adjust_type),
+                       call = caller_env()) {
+  # if no `adjust_type` was supplied, infer a reasonable one based on the
+  # `container_type`
+  if (is.null(adjust_type)) {
+    default_type <- switch(
+      container_type,
+      regression = "linear",
+      binary = "logistic",
+      multiclass = "multinomial"
+    )
+    # `switch()` yields `NULL` for an unrecognized `container_type`; fail here
+    # rather than letting `arg_match0()` below choke on a `NULL` `adjust_type`
+    if (is.null(default_type)) {
+      cli_abort(
+        "Can't infer a calibration type for a {.field {container_type}} \\
+        container; please supply {.arg type}.",
+        call = call
+      )
+    }
+    return(default_type)
+  }
+
+  # otherwise, check `adjust_type` against the types allowed for this
+  # `container_type` (or against every known type if it is unrecognized)
+  switch(
+    container_type,
+    regression = arg_match0(
+      adjust_type,
+      types_regression,
+      arg_nm = arg,
+      error_call = call
+    ),
+    binary = arg_match0(
+      adjust_type,
+      types_binary,
+      arg_nm = arg,
+      error_call = call
+    ),
+    multiclass = arg_match0(
+      adjust_type,
+      types_multiclass,
+      arg_nm = arg,
+      error_call = call
+    ),
+    arg_match0(
+      adjust_type,
+      unique(c(types_regression, types_binary, types_multiclass)),
+      arg_nm = arg,
+      error_call = call
+    )
+  )
+
+  adjust_type
+}
+
diff --git a/man/adjust_numeric_calibration.Rd b/man/adjust_numeric_calibration.Rd
index f8e6315..0650ac0 100644
--- a/man/adjust_numeric_calibration.Rd
+++ b/man/adjust_numeric_calibration.Rd
@@ -4,13 +4,16 @@
 \alias{adjust_numeric_calibration}
 \title{Re-calibrate numeric predictions}
 \usage{
-adjust_numeric_calibration(x, calibrator)
+adjust_numeric_calibration(x, type = NULL)
 }
 \arguments{
 \item{x}{A \code{\link[=container]{container()}}.}
 
-\item{calibrator}{A pre-trained calibration method from the \pkg{probably}
-package, such as \code{\link[probably:cal_estimate_linear]{probably::cal_estimate_linear()}}.}
+\item{type}{Character.
One of \code{"linear"}, \code{"isotonic"}, or +\code{"isotonic_boot"}, corresponding to the function from the \pkg{probably} +package \code{\link[probably:cal_estimate_linear]{probably::cal_estimate_linear()}}, +\code{\link[probably:cal_estimate_isotonic]{probably::cal_estimate_isotonic()}}, or +\code{\link[probably:cal_estimate_isotonic_boot]{probably::cal_estimate_isotonic_boot()}}, respectively.} } \description{ Re-calibrate numeric predictions @@ -26,16 +29,13 @@ dat <- tibble(y = rnorm(100), y_pred = y/2 + rnorm(100)) dat -# calibrate numeric predictions -reg_cal <- cal_estimate_linear(dat, truth = y, estimate = y_pred) - # specify calibration reg_ctr <- container(mode = "regression") \%>\% - adjust_numeric_calibration(reg_cal) + adjust_numeric_calibration(type = "linear") -# "train" container +# train container reg_ctr_trained <- fit(reg_ctr, dat, outcome = y, estimate = y_pred) -predict(reg_ctr, dat) +predict(reg_ctr_trained, dat) } diff --git a/man/adjust_probability_calibration.Rd b/man/adjust_probability_calibration.Rd index 65e0392..3bd2adf 100644 --- a/man/adjust_probability_calibration.Rd +++ b/man/adjust_probability_calibration.Rd @@ -4,13 +4,15 @@ \alias{adjust_probability_calibration} \title{Re-calibrate classification probability predictions} \usage{ -adjust_probability_calibration(x, calibrator) +adjust_probability_calibration(x, type = NULL) } \arguments{ \item{x}{A \code{\link[=container]{container()}}.} -\item{calibrator}{A pre-trained calibration method from the \pkg{probably} -package, such as \code{\link[probably:cal_estimate_logistic]{probably::cal_estimate_logistic()}}.} +\item{type}{Character. 
One of \code{"logistic"}, \code{"multinomial"}, +\code{"beta"}, \code{"isotonic"}, or \code{"isotonic_boot"}, corresponding to the +function from the \pkg{probably} package \code{\link[probably:cal_estimate_logistic]{probably::cal_estimate_logistic()}}, +\code{\link[probably:cal_estimate_multinomial]{probably::cal_estimate_multinomial()}}, etc., respectively.} } \description{ Re-calibrate classification probability predictions diff --git a/tests/testthat/_snaps/adjust-numeric-calibration.md b/tests/testthat/_snaps/adjust-numeric-calibration.md index dd4d98c..0d904a3 100644 --- a/tests/testthat/_snaps/adjust-numeric-calibration.md +++ b/tests/testthat/_snaps/adjust-numeric-calibration.md @@ -1,35 +1,36 @@ # adjustment printing Code - ctr_reg %>% adjust_numeric_calibration(dummy_reg_cal) + ctr_reg %>% adjust_numeric_calibration() Message -- Container ------------------------------------------------------------------- - A postprocessor with 1 operation: + A regression postprocessor with 1 operation: * Re-calibrate numeric predictions. # errors informatively with bad input Code - adjust_numeric_calibration(ctr_reg) + adjust_numeric_calibration(ctr_reg, "boop") Condition Error in `adjust_numeric_calibration()`: - ! `calibrator` is absent but must be supplied. + ! `type` must be one of "linear", "isotonic", or "isotonic_boot", not "boop". --- Code - adjust_numeric_calibration(ctr_reg, "boop") + container("classification", "binary") %>% adjust_numeric_calibration("linear") Condition Error in `adjust_numeric_calibration()`: - ! `calibrator` should be a object (`?probably::cal_estimate_linear()`), not a string. + ! A binary container is incompatible with the operation `adjust_numeric_calibration()`. --- Code - adjust_numeric_calibration(ctr_cls, dummy_cls_cal) + container("regression", "regression") %>% adjust_numeric_calibration("binary") Condition Error in `adjust_numeric_calibration()`: - ! 
`calibrator` should be a object (`?probably::cal_estimate_linear()`), not a object. + ! `type` must be one of "linear", "isotonic", or "isotonic_boot", not "binary". + i Did you mean "linear"? diff --git a/tests/testthat/_snaps/adjust-numeric-range.md b/tests/testthat/_snaps/adjust-numeric-range.md index afa9537..b3df879 100644 --- a/tests/testthat/_snaps/adjust-numeric-range.md +++ b/tests/testthat/_snaps/adjust-numeric-range.md @@ -5,7 +5,7 @@ Message -- Container ------------------------------------------------------------------- - A postprocessor with 1 operation: + A regression postprocessor with 1 operation: * Constrain numeric predictions to be between [-Inf, Inf]. @@ -16,7 +16,7 @@ Message -- Container ------------------------------------------------------------------- - A postprocessor with 1 operation: + A regression postprocessor with 1 operation: * Constrain numeric predictions to be between [?, Inf]. @@ -27,7 +27,7 @@ Message -- Container ------------------------------------------------------------------- - A postprocessor with 1 operation: + A regression postprocessor with 1 operation: * Constrain numeric predictions to be between [-1, ?]. @@ -38,7 +38,7 @@ Message -- Container ------------------------------------------------------------------- - A postprocessor with 1 operation: + A regression postprocessor with 1 operation: * Constrain numeric predictions to be between [?, 1]. 
diff --git a/tests/testthat/_snaps/adjust-probability-calibration.md b/tests/testthat/_snaps/adjust-probability-calibration.md index 52a037e..2fefbea 100644 --- a/tests/testthat/_snaps/adjust-probability-calibration.md +++ b/tests/testthat/_snaps/adjust-probability-calibration.md @@ -1,7 +1,7 @@ # adjustment printing Code - ctr_cls %>% adjust_probability_calibration(dummy_cls_cal) + ctr_cls %>% adjust_probability_calibration("logistic") Message -- Container ------------------------------------------------------------------- @@ -12,24 +12,26 @@ # errors informatively with bad input Code - adjust_probability_calibration(ctr_cls) + adjust_probability_calibration(ctr_cls, "boop") Condition Error in `adjust_probability_calibration()`: - ! `calibrator` is absent but must be supplied. + ! `type` must be one of "logistic", "multinomial", "beta", "isotonic", or "isotonic_boot", not "boop". --- Code - adjust_probability_calibration(ctr_cls, "boop") + container("regression", "regression") %>% adjust_probability_calibration( + "binary") Condition Error in `adjust_probability_calibration()`: - ! `calibrator` should be a or object (`?probably::cal_estimate_logistic()`), not a string. + ! A regression container is incompatible with the operation `adjust_probability_calibration()`. --- Code - adjust_probability_calibration(ctr_cls, dummy_reg_cal) + container("classification", "binary") %>% adjust_probability_calibration( + "linear") Condition Error in `adjust_probability_calibration()`: - ! `calibrator` should be a or object (`?probably::cal_estimate_logistic()`), not a object. + ! `type` must be one of "logistic", "multinomial", "beta", "isotonic", or "isotonic_boot", not "linear". 
diff --git a/tests/testthat/_snaps/validation-rules.md b/tests/testthat/_snaps/validation-rules.md index a4ae358..df42bda 100644 --- a/tests/testthat/_snaps/validation-rules.md +++ b/tests/testthat/_snaps/validation-rules.md @@ -2,8 +2,7 @@ Code container(mode = "regression") %>% adjust_numeric_range(lower_limit = 2) %>% - adjust_numeric_calibration(dummy_reg_cal) %>% adjust_predictions_custom( - squared = .pred^2) + adjust_numeric_calibration() %>% adjust_predictions_custom(squared = .pred^2) Condition Error in `adjust_numeric_calibration()`: ! Calibration should come before other operations. @@ -12,7 +11,7 @@ Code container(mode = "classification") %>% adjust_probability_threshold(threshold = 0.4) %>% - adjust_probability_calibration(dummy_cls_cal) + adjust_probability_calibration() Condition Error in `adjust_probability_calibration()`: ! Operations that change the hard class predictions must come after operations that update the class probability estimates. @@ -22,7 +21,7 @@ Code container(mode = "classification") %>% adjust_predictions_custom(veg = "potato") %>% adjust_probability_threshold(threshold = 0.4) %>% - adjust_probability_calibration(dummy_cls_cal) + adjust_probability_calibration() Condition Error in `adjust_probability_calibration()`: ! Operations that change the hard class predictions must come after operations that update the class probability estimates. @@ -33,7 +32,7 @@ container(mode = "classification") %>% adjust_predictions_custom(veg = "potato") %>% adjust_probability_threshold(threshold = 0.4) %>% adjust_probability_threshold(threshold = 0.5) %>% - adjust_probability_calibration(dummy_cls_cal) + adjust_probability_calibration() Condition Error in `adjust_probability_threshold()`: ! 
Operations cannot be duplicated: "probability_threshold"
diff --git a/tests/testthat/test-adjust-numeric-calibration.R b/tests/testthat/test-adjust-numeric-calibration.R
index b67b717..d664653 100644
--- a/tests/testthat/test-adjust-numeric-calibration.R
+++ b/tests/testthat/test-adjust-numeric-calibration.R
@@ -1,13 +1,22 @@
 test_that("adjustment printing", {
-  dummy_reg_cal <- structure(list(), class = "cal_regression")
-  expect_snapshot(ctr_reg %>% adjust_numeric_calibration(dummy_reg_cal))
+  expect_snapshot(ctr_reg %>% adjust_numeric_calibration())
 })
 
 test_that("errors informatively with bad input", {
   # check for `adjust_numeric_calibration(container)` is in `utils.R` tests
 
-  expect_snapshot(error = TRUE, adjust_numeric_calibration(ctr_reg))
   expect_snapshot(error = TRUE, adjust_numeric_calibration(ctr_reg, "boop"))
-  dummy_cls_cal <- structure(list(), class = "cal_binary")
-  expect_snapshot(error = TRUE, adjust_numeric_calibration(ctr_cls, dummy_cls_cal))
+  expect_snapshot(
+    error = TRUE,
+    container("classification", "binary") %>% adjust_numeric_calibration("linear")
+  )
+  expect_snapshot(
+    error = TRUE,
+    container("regression", "regression") %>% adjust_numeric_calibration("binary")
+  )
+  # todo: this should error, mode is incompatible even though type is fine
+  # expect_snapshot(error = TRUE, adjust_numeric_calibration(ctr_cls, "linear"))
+
+  expect_no_condition(adjust_numeric_calibration(ctr_reg))
+  expect_no_condition(adjust_numeric_calibration(ctr_reg, "linear"))
 })
diff --git a/tests/testthat/test-adjust-probability-calibration.R b/tests/testthat/test-adjust-probability-calibration.R
index 193ef1d..9e68d51 100644
--- a/tests/testthat/test-adjust-probability-calibration.R
+++ b/tests/testthat/test-adjust-probability-calibration.R
@@ -1,16 +1,22 @@
 test_that("adjustment printing", {
-  dummy_cls_cal <- structure(list(), class = "cal_binary")
-  expect_snapshot(ctr_cls %>% adjust_probability_calibration(dummy_cls_cal))
+  expect_snapshot(ctr_cls %>% adjust_probability_calibration("logistic"))
 })
 
 test_that("errors informatively with bad input", {
   # check for `adjust_probably_calibration(container)` is in `utils.R` tests
 
-  expect_snapshot(error = TRUE, adjust_probability_calibration(ctr_cls))
   expect_snapshot(error = TRUE, adjust_probability_calibration(ctr_cls, "boop"))
-  dummy_reg_cal <- structure(list(), class = "cal_regression")
   expect_snapshot(
     error = TRUE,
-    adjust_probability_calibration(ctr_cls, dummy_reg_cal)
+    container("regression", "regression") %>% adjust_probability_calibration("binary")
   )
+  expect_snapshot(
+    error = TRUE,
+    container("classification", "binary") %>% adjust_probability_calibration("linear")
+  )
+  # todo: this should error, mode is incompatible even though type is fine
+  # expect_snapshot(error = TRUE, adjust_probability_calibration(ctr_reg, "logistic"))
+
+  expect_no_condition(adjust_probability_calibration(ctr_cls))
+  expect_no_condition(adjust_probability_calibration(ctr_cls, "logistic"))
 })
diff --git a/tests/testthat/test-validation-rules.R b/tests/testthat/test-validation-rules.R
index c86cb1b..421130a 100644
--- a/tests/testthat/test-validation-rules.R
+++ b/tests/testthat/test-validation-rules.R
@@ -1,11 +1,8 @@
 test_that("validation of operations (regression)", {
-  dummy_reg_cal <- list()
-  class(dummy_reg_cal) <- "cal_regression"
-
   expect_silent(
     reg_ctr <-
       container(mode = "regression") %>%
-      adjust_numeric_calibration(dummy_reg_cal) %>%
+      adjust_numeric_calibration() %>%
       adjust_numeric_range(lower_limit = 2) %>%
       adjust_predictions_custom(squared = .pred^2)
   )
@@ -13,7 +10,7 @@ test_that("validation of operations (regression)", {
   expect_snapshot(
     container(mode = "regression") %>%
       adjust_numeric_range(lower_limit = 2) %>%
-      adjust_numeric_calibration(dummy_reg_cal) %>%
+      adjust_numeric_calibration() %>%
       adjust_predictions_custom(squared = .pred^2),
     error = TRUE
   )
@@ -24,19 +21,17 @@ test_that("validation of operations (regression)", {
     reg_ctr <-
       container(mode = "regression") %>%
       adjust_predictions_custom(squared = .pred^2) %>%
-      adjust_numeric_calibration(dummy_reg_cal) %>%
+      adjust_numeric_calibration() %>%
       adjust_numeric_range(lower_limit = 2)
   )
 })
 
 test_that("validation of operations (classification)", {
-  dummy_cls_cal <- list()
-  class(dummy_cls_cal) <- "cal_binary"
-
   expect_silent(
     cls_ctr_1 <-
       container(mode = "classification") %>%
-      adjust_probability_calibration(dummy_cls_cal) %>%
+      # to-do: should be able to supply no `type` argument here
+      adjust_probability_calibration("logistic") %>%
       adjust_probability_threshold(threshold = .4)
   )
 
@@ -45,14 +40,14 @@ test_that("validation of operations (classification)", {
       container(mode = "classification") %>%
       adjust_predictions_custom(starch = "potato") %>%
       adjust_predictions_custom(veg = "green beans") %>%
-      adjust_probability_calibration(dummy_cls_cal) %>%
+      adjust_probability_calibration("logistic") %>%
       adjust_probability_threshold(threshold = .4)
   )
 
   expect_snapshot(
     container(mode = "classification") %>%
       adjust_probability_threshold(threshold = .4) %>%
-      adjust_probability_calibration(dummy_cls_cal),
+      adjust_probability_calibration(),
     error = TRUE
   )
 
@@ -60,7 +55,7 @@ test_that("validation of operations (classification)", {
     container(mode = "classification") %>%
       adjust_predictions_custom(veg = "potato") %>%
       adjust_probability_threshold(threshold = .4) %>%
-      adjust_probability_calibration(dummy_cls_cal),
+      adjust_probability_calibration(),
     error = TRUE
   )
 
@@ -69,7 +64,7 @@ test_that("validation of operations (classification)", {
       adjust_predictions_custom(veg = "potato") %>%
       adjust_probability_threshold(threshold = .4) %>%
       adjust_probability_threshold(threshold = .5) %>%
-      adjust_probability_calibration(dummy_cls_cal),
+      adjust_probability_calibration(),
     error = TRUE
   )
 