Skip to content

Commit

Permalink
add helper for whether tailor has operations that require training
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch committed Jun 3, 2024
1 parent 75a8f60 commit b6deda5
Show file tree
Hide file tree
Showing 10 changed files with 65 additions and 14 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ export(fit)
export(required_pkgs)
export(tailor)
export(tailor_fully_trained)
export(tailor_requires_fit)
export(tidy)
export(tunable)
export(tune_args)
Expand Down
6 changes: 4 additions & 2 deletions R/adjust-equivocal-zone.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ adjust_equivocal_zone <- function(x, value = 0.1, threshold = 1 / 2) {
outputs = "class",
arguments = list(value = value, threshold = threshold),
results = list(),
trained = FALSE
trained = FALSE,
requires_fit = FALSE
)

new_tailor(
Expand Down Expand Up @@ -75,7 +76,8 @@ fit.equivocal_zone <- function(object, data, tailor = NULL, ...) {
outputs = object$outputs,
arguments = object$arguments,
results = list(),
trained = TRUE
trained = TRUE,
requires_fit = object$requires_fit
)
}

Expand Down
6 changes: 4 additions & 2 deletions R/adjust-numeric-calibration.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ adjust_numeric_calibration <- function(x, method = NULL) {
outputs = "numeric",
arguments = list(method = method),
results = list(),
trained = FALSE
trained = TRUE,
requires_fit = TRUE
)

new_tailor(
Expand Down Expand Up @@ -86,7 +87,8 @@ fit.numeric_calibration <- function(object, data, tailor = NULL, ...) {
outputs = object$outputs,
arguments = object$arguments,
results = list(fit = fit),
trained = TRUE
trained = TRUE,
requires_fit = object$requires_fit
)
}

Expand Down
6 changes: 4 additions & 2 deletions R/adjust-numeric-range.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ adjust_numeric_range <- function(x, lower_limit = -Inf, upper_limit = Inf) {
outputs = "numeric",
arguments = list(lower_limit = lower_limit, upper_limit = upper_limit),
results = list(),
trained = FALSE
trained = FALSE,
requires_fit = FALSE
)

new_tailor(
Expand Down Expand Up @@ -64,7 +65,8 @@ fit.numeric_range <- function(object, data, tailor = NULL, ...) {
outputs = object$outputs,
arguments = object$arguments,
results = list(),
trained = TRUE
trained = TRUE,
requires_fit = object$requires_fit
)
}

Expand Down
8 changes: 6 additions & 2 deletions R/adjust-predictions-custom.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ adjust_predictions_custom <- function(x, ..., .pkgs = character(0)) {
outputs = "everything",
arguments = list(commands = cmds, pkgs = .pkgs),
results = list(),
trained = FALSE
trained = FALSE,
# todo: should there be a user interface to tell tailor whether this
# adjustment requires fit?
requires_fit = FALSE
)

new_tailor(
Expand All @@ -62,7 +65,8 @@ fit.predictions_custom <- function(object, data, tailor = NULL, ...) {
outputs = object$outputs,
arguments = object$arguments,
results = list(),
trained = TRUE
trained = TRUE,
requires_fit = object$requires_fit
)
}

Expand Down
6 changes: 4 additions & 2 deletions R/adjust-probability-calibration.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ adjust_probability_calibration <- function(x, method = NULL) {
outputs = "probability_class",
arguments = list(method = method),
results = list(),
trained = FALSE
trained = FALSE,
requires_fit = TRUE
)

new_tailor(
Expand Down Expand Up @@ -67,7 +68,8 @@ fit.probability_calibration <- function(object, data, tailor = NULL, ...) {
outputs = object$outputs,
arguments = object$arguments,
results = list(fit = fit),
trained = TRUE
trained = TRUE,
requires_fit = object$requires_fit
)
}

Expand Down
6 changes: 4 additions & 2 deletions R/adjust-probability-threshold.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ adjust_probability_threshold <- function(x, threshold = 0.5) {
outputs = "class",
arguments = list(threshold = threshold),
results = list(),
trained = FALSE
trained = FALSE,
requires_fit = FALSE
)

new_tailor(
Expand Down Expand Up @@ -71,7 +72,8 @@ fit.probability_threshold <- function(object, data, tailor = NULL, ...) {
outputs = object$outputs,
arguments = object$arguments,
results = list(),
trained = TRUE
trained = TRUE,
requires_fit = object$requires_fit
)
}

Expand Down
16 changes: 14 additions & 2 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ input_vals <- c("numeric", "probability", "class", "everything")
output_vals <- c("numeric", "probability_class", "class", "everything")

new_operation <- function(cls, inputs, outputs, arguments, results = list(),
trained, ...) {
trained, requires_fit, ...) {
inputs <- arg_match0(inputs, input_vals)
outputs <- arg_match0(outputs, output_vals)

Expand All @@ -46,7 +46,8 @@ new_operation <- function(cls, inputs, outputs, arguments, results = list(),
outputs = outputs,
arguments = arguments,
results = results,
trained = trained
trained = trained,
requires_fit = requires_fit
)
class(res) <- c(cls, "operation")
res
Expand All @@ -72,6 +73,17 @@ tailor_operation_trained <- function(x) {
isTRUE(x$trained)
}

#' @export
#' @keywords internal
#' @rdname tailor-internals
tailor_requires_fit <- function(x) {
any(purrr::map_lgl(x$operations, tailor_operation_requires_fit))
}

tailor_operation_requires_fit <- function(x) {
isTRUE(x$requires_fit)
}

# ad-hoc checking --------------------------------------------------------------
check_tailor <- function(x, calibration_type = NULL, call = caller_env(), arg = caller_arg(x)) {
if (!is_tailor(x)) {
Expand Down
3 changes: 3 additions & 0 deletions man/tailor-internals.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 21 additions & 0 deletions tests/testthat/test-utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,24 @@ test_that("tailor_fully_trained works", {
)
)
})


test_that("tailor_requires_fit works", {
expect_false(tailor_requires_fit(tailor()))
expect_false(
tailor_requires_fit(tailor() %>% adjust_probability_threshold(.5))
)
expect_true(
tailor_requires_fit(
tailor() %>%
adjust_probability_calibration("logistic")
)
)
expect_true(
tailor_requires_fit(
tailor() %>%
adjust_probability_calibration("logistic") %>%
adjust_probability_threshold(.5)
)
)
})

0 comments on commit b6deda5

Please sign in to comment.