From b6deda5be3b1702b8b7b9863b60c6bb61dab1d2a Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Mon, 3 Jun 2024 09:58:09 -0500 Subject: [PATCH] add helper for whether tailor has operations that require training --- NAMESPACE | 1 + R/adjust-equivocal-zone.R | 6 ++++-- R/adjust-numeric-calibration.R | 6 ++++-- R/adjust-numeric-range.R | 6 ++++-- R/adjust-predictions-custom.R | 8 ++++++-- R/adjust-probability-calibration.R | 6 ++++-- R/adjust-probability-threshold.R | 6 ++++-- R/utils.R | 16 ++++++++++++++-- man/tailor-internals.Rd | 3 +++ tests/testthat/test-utils.R | 21 +++++++++++++++++++++ 10 files changed, 65 insertions(+), 14 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index c39a2fc..d0e9232 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -46,6 +46,7 @@ export(fit) export(required_pkgs) export(tailor) export(tailor_fully_trained) +export(tailor_requires_fit) export(tidy) export(tunable) export(tune_args) diff --git a/R/adjust-equivocal-zone.R b/R/adjust-equivocal-zone.R index bcc4508..6c6be5c 100644 --- a/R/adjust-equivocal-zone.R +++ b/R/adjust-equivocal-zone.R @@ -39,7 +39,8 @@ adjust_equivocal_zone <- function(x, value = 0.1, threshold = 1 / 2) { outputs = "class", arguments = list(value = value, threshold = threshold), results = list(), - trained = FALSE + trained = FALSE, + requires_fit = FALSE ) new_tailor( @@ -75,7 +76,8 @@ fit.equivocal_zone <- function(object, data, tailor = NULL, ...) { outputs = object$outputs, arguments = object$arguments, results = list(), - trained = TRUE + trained = TRUE, + requires_fit = object$requires_fit ) } diff --git a/R/adjust-numeric-calibration.R b/R/adjust-numeric-calibration.R index aeab2ac..c54cbb5 100644 --- a/R/adjust-numeric-calibration.R +++ b/R/adjust-numeric-calibration.R @@ -45,7 +45,8 @@ adjust_numeric_calibration <- function(x, method = NULL) { outputs = "numeric", arguments = list(method = method), results = list(), - trained = FALSE + trained = TRUE, + requires_fit = TRUE ) new_tailor( @@ -86,7 +87,8 @@ fit.numeric_calibration <- function(object, data, tailor = NULL, ...) { outputs = object$outputs, arguments = object$arguments, results = list(fit = fit), - trained = TRUE + trained = TRUE, + requires_fit = object$requires_fit ) } diff --git a/R/adjust-numeric-range.R b/R/adjust-numeric-range.R index 35691e5..db72849 100644 --- a/R/adjust-numeric-range.R +++ b/R/adjust-numeric-range.R @@ -15,7 +15,8 @@ adjust_numeric_range <- function(x, lower_limit = -Inf, upper_limit = Inf) { outputs = "numeric", arguments = list(lower_limit = lower_limit, upper_limit = upper_limit), results = list(), - trained = FALSE + trained = FALSE, + requires_fit = FALSE ) new_tailor( @@ -64,7 +65,8 @@ fit.numeric_range <- function(object, data, tailor = NULL, ...) { outputs = object$outputs, arguments = object$arguments, results = list(), - trained = TRUE + trained = TRUE, + requires_fit = object$requires_fit ) } diff --git a/R/adjust-predictions-custom.R b/R/adjust-predictions-custom.R index ea76318..18e166a 100644 --- a/R/adjust-predictions-custom.R +++ b/R/adjust-predictions-custom.R @@ -35,7 +35,10 @@ adjust_predictions_custom <- function(x, ..., .pkgs = character(0)) { outputs = "everything", arguments = list(commands = cmds, pkgs = .pkgs), results = list(), - trained = FALSE + trained = FALSE, + # todo: should there be a user interface to tell tailor whether this + # adjustment requires fit? + requires_fit = FALSE ) new_tailor( @@ -62,7 +65,8 @@ fit.predictions_custom <- function(object, data, tailor = NULL, ...) { outputs = object$outputs, arguments = object$arguments, results = list(), - trained = TRUE + trained = TRUE, + requires_fit = object$requires_fit ) } diff --git a/R/adjust-probability-calibration.R b/R/adjust-probability-calibration.R index 428eda4..13d5004 100644 --- a/R/adjust-probability-calibration.R +++ b/R/adjust-probability-calibration.R @@ -24,7 +24,8 @@ adjust_probability_calibration <- function(x, method = NULL) { outputs = "probability_class", arguments = list(method = method), results = list(), - trained = FALSE + trained = FALSE, + requires_fit = TRUE ) new_tailor( @@ -67,7 +68,8 @@ fit.probability_calibration <- function(object, data, tailor = NULL, ...) { outputs = object$outputs, arguments = object$arguments, results = list(fit = fit), - trained = TRUE + trained = TRUE, + requires_fit = object$requires_fit ) } diff --git a/R/adjust-probability-threshold.R b/R/adjust-probability-threshold.R index f224a58..552d02d 100644 --- a/R/adjust-probability-threshold.R +++ b/R/adjust-probability-threshold.R @@ -35,7 +35,8 @@ adjust_probability_threshold <- function(x, threshold = 0.5) { outputs = "class", arguments = list(threshold = threshold), results = list(), - trained = FALSE + trained = FALSE, + requires_fit = FALSE ) new_tailor( @@ -71,7 +72,8 @@ fit.probability_threshold <- function(object, data, tailor = NULL, ...) { outputs = object$outputs, arguments = object$arguments, results = list(), - trained = TRUE + trained = TRUE, + requires_fit = object$requires_fit ) } diff --git a/R/utils.R b/R/utils.R index e467eb8..9c696d4 100644 --- a/R/utils.R +++ b/R/utils.R @@ -34,7 +34,7 @@ input_vals <- c("numeric", "probability", "class", "everything") output_vals <- c("numeric", "probability_class", "class", "everything") new_operation <- function(cls, inputs, outputs, arguments, results = list(), - trained, ...) { + trained, requires_fit, ...) { inputs <- arg_match0(inputs, input_vals) outputs <- arg_match0(outputs, output_vals) @@ -46,7 +46,8 @@ new_operation <- function(cls, inputs, outputs, arguments, results = list(), outputs = outputs, arguments = arguments, results = results, - trained = trained + trained = trained, + requires_fit = requires_fit ) class(res) <- c(cls, "operation") res @@ -72,6 +73,17 @@ tailor_operation_trained <- function(x) { isTRUE(x$trained) } +#' @export +#' @keywords internal +#' @rdname tailor-internals +tailor_requires_fit <- function(x) { + any(purrr::map_lgl(x$operations, tailor_operation_requires_fit)) +} + +tailor_operation_requires_fit <- function(x) { + isTRUE(x$requires_fit) +} + # ad-hoc checking -------------------------------------------------------------- check_tailor <- function(x, calibration_type = NULL, call = caller_env(), arg = caller_arg(x)) { if (!is_tailor(x)) { diff --git a/man/tailor-internals.Rd b/man/tailor-internals.Rd index 5c4e355..c71b60a 100644 --- a/man/tailor-internals.Rd +++ b/man/tailor-internals.Rd @@ -3,9 +3,12 @@ \name{tailor-internals} \alias{tailor-internals} \alias{tailor_fully_trained} +\alias{tailor_requires_fit} \title{Internal tailor functions} \usage{ tailor_fully_trained(x) + +tailor_requires_fit(x) } \description{ Utilities for use in downstream packages. diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R index 88737cc..aec397d 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-utils.R @@ -50,3 +50,24 @@ test_that("tailor_fully_trained works", { ) ) }) + + +test_that("tailor_requires_fit works", { + expect_false(tailor_requires_fit(tailor())) + expect_false( + tailor_requires_fit(tailor() %>% adjust_probability_threshold(.5)) + ) + expect_true( + tailor_requires_fit( + tailor() %>% + adjust_probability_calibration("logistic") + ) + ) + expect_true( + tailor_requires_fit( + tailor() %>% + adjust_probability_calibration("logistic") %>% + adjust_probability_threshold(.5) + ) + ) +})