Skip to content

Commit

Permalink
implement tune_args() and tunable()
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch committed Oct 16, 2024
1 parent 317a4db commit 95e28be
Show file tree
Hide file tree
Showing 16 changed files with 206 additions and 9 deletions.
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,16 @@ S3method(required_pkgs,numeric_range)
S3method(required_pkgs,predictions_custom)
S3method(required_pkgs,probability_calibration)
S3method(required_pkgs,probability_threshold)
S3method(tunable,adjustment)
S3method(tunable,equivocal_zone)
S3method(tunable,numeric_calibration)
S3method(tunable,numeric_range)
S3method(tunable,predictions_custom)
S3method(tunable,probability_calibration)
S3method(tunable,probability_threshold)
S3method(tunable,tailor)
S3method(tune_args,adjustment)
S3method(tune_args,tailor)
export("%>%")
export(adjust_equivocal_zone)
export(adjust_numeric_calibration)
Expand Down
1 change: 0 additions & 1 deletion R/adjust-equivocal-zone.R
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,5 @@ tunable.equivocal_zone <- function(x, ...) {
}

# todo missing methods:
# todo tune_args
# todo tidy
# todo extract_parameter_set_dials
1 change: 0 additions & 1 deletion R/adjust-numeric-calibration.R
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,5 @@ tunable.numeric_calibration <- function(x, ...) {
}

# todo missing methods:
# todo tune_args
# todo tidy
# todo extract_parameter_set_dials
1 change: 0 additions & 1 deletion R/adjust-numeric-range.R
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,5 @@ tunable.numeric_range <- function(x, ...) {
}

# todo missing methods:
# todo tune_args
# todo tidy
# todo extract_parameter_set_dials
1 change: 0 additions & 1 deletion R/adjust-predictions-custom.R
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,5 @@ tunable.predictions_custom <- function(x, ...) {
}

# todo missing methods:
# todo tune_args
# todo tidy
# todo extract_parameter_set_dials
1 change: 0 additions & 1 deletion R/adjust-probability-calibration.R
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,5 @@ tunable.probability_calibration <- function(x, ...) {
}

# todo missing methods:
# todo tune_args
# todo tidy
# todo extract_parameter_set_dials
1 change: 0 additions & 1 deletion R/adjust-probability-threshold.R
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,5 @@ tunable.probability_threshold <- function(x, ...) {
}

# todo missing methods:
# todo tune_args
# todo tidy
# todo extract_parameter_set_dials
38 changes: 37 additions & 1 deletion R/tailor.R
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,42 @@ set_tailor_type <- function(object, y, call = caller_env()) {
# todo: where to validate #levels?
# todo setup eval_time
# todo missing methods:
# todo tune_args

#' @export
tune_args.tailor <- function(object, full = FALSE, ...) {
adjustments <- object$adjustments

if (length(adjustments) == 0L) {
return(tune_tbl())
}

res <- purrr::map(object$adjustments, tune_args, full = full)
res <- purrr::list_rbind(res)

tune_tbl(
res$name,
res$tunable,
res$id,
res$source,
res$component,
res$component_id,
full = full
)
}

#' @export
tunable.tailor <- function(x, ...) {
if (length(x$adjustments) == 0) {
res <- no_param
} else {
res <- purrr::map(x$adjustments, tunable)
res <- vctrs::vec_rbind(!!!res)
if (nrow(res) > 0) {
res <- res[!is.na(res$name), ]
}
}
res
}

# todo tidy (this should probably just be `adjustment_orderings()`)
# todo extract_parameter_set_dials
69 changes: 67 additions & 2 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,80 @@
#' @name tailor-internals
NULL


# tuning machinery -------------------------------------------------------------
is_tune <- function(x) {
if (!is.call(x)) {
return(FALSE)
}
isTRUE(identical(quote(tune), x[[1]]))
}

# for adjustments with no tunable parameters
tune_tbl <- function(name = character(), tunable = logical(), id = character(),
source = character(), component = character(),
component_id = character(), full = FALSE, call = caller_env()) {
complete_id <- id[!is.na(id)]
dups <- duplicated(complete_id)

if (any(dups)) {
offenders <- unique(complete_id[dups])
cli::cli_abort(
"{.val {offenders}} {?has a/have} duplicate {.field id} value{?s}.",
call = call
)
}

vry_tbl <-
tibble::new_tibble(list(
name = as.character(name),
tunable = as.logical(tunable),
id = as.character(id),
source = as.character(source),
component = as.character(component),
component_id = as.character(component_id)
))

if (!full) {
vry_tbl <- vry_tbl[vry_tbl$tunable, ]
}

vry_tbl
}

#' @export
tune_args.adjustment <- function(object, full = FALSE, ...) {
adjustment_id <- object$id
# Grab the adjustment class before the subset, as that removes the class
adjustment_type <- class(object)[1]

tune_param_list <- tunable(object)$name

# remove the non-tunable arguments as they are not important
object <- object[tune_param_list]

# Remove NULL argument adjustments. These are reserved
# for deprecated args or those set at fit() time.
object <- object[!purrr::map_lgl(object, is.null)]

res <- purrr::map_chr(object, find_tune_id)
res <- ifelse(res == "", names(res), res)

tune_tbl(
name = names(res),
tunable = unname(!is.na(res)),
id = unname(res),
source = "tailor",
component = adjustment_type,
component_id = adjustment_id,
full = full
)
}

#' @export
tunable.adjustment <- function(x, ...) {
no_param
}

# for adjustments with no tunable parameters
no_param <-
tibble::tibble(
name = character(0),
Expand All @@ -25,6 +89,7 @@ no_param <-
component_id = character(0)
)

# new_adjustment -------------------------------------------------------------
# These values are used to specify "what will we need for the adjustment?" and
# "what will we change?". For the outputs, we cannot change the probabilities
# without changing the classes. This is important because we are going to have
Expand Down
15 changes: 15 additions & 0 deletions tests/testthat/test-adjust-equivocal-zone.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,18 @@ test_that("adjustment printing", {
expect_snapshot(tailor() %>% adjust_equivocal_zone())
expect_snapshot(tailor() %>% adjust_equivocal_zone(hardhat::tune()))
})

test_that("tunable", {
tlr <-
tailor() %>%
adjust_equivocal_zone(value = 1 / 4)
adj_param <- tunable(tlr$adjustments[[1]])
expect_equal(adj_param$name, c("buffer"))
expect_true(all(adj_param$source == "tailor"))
expect_true(is.list(adj_param$call_info))
expect_equal(nrow(adj_param), 1)
expect_equal(
names(adj_param),
c("name", "call_info", "source", "component", "component_id")
)
})
8 changes: 8 additions & 0 deletions tests/testthat/test-adjust-numeric-calibration.R
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,11 @@ test_that("errors informatively with bad input", {
expect_no_condition(adjust_numeric_calibration(tailor()))
expect_no_condition(adjust_numeric_calibration(tailor(), "linear"))
})

test_that("tunable", {
tlr <-
tailor() %>%
adjust_numeric_calibration(method = "linear")
adj_param <- tunable(tlr$adjustments[[1]])
expect_equal(adj_param, no_param)
})
14 changes: 14 additions & 0 deletions tests/testthat/test-adjust-numeric-range.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,17 @@ test_that("adjustment printing", {
expect_snapshot(tailor() %>% adjust_numeric_range(hardhat::tune(), 1))
})

test_that("tunable", {
tlr <-
tailor() %>%
adjust_numeric_range(lower_limit = 1, upper_limit = 2)
adj_param <- tunable(tlr$adjustments[[1]])
expect_equal(adj_param$name, c("lower_limit", "upper_limit"))
expect_true(all(adj_param$source == "tailor"))
expect_true(is.list(adj_param$call_info))
expect_equal(nrow(adj_param), 2)
expect_equal(
names(adj_param),
c("name", "call_info", "source", "component", "component_id")
)
})
8 changes: 8 additions & 0 deletions tests/testthat/test-adjust-predictions-custom.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,11 @@ test_that("basic adjust_predictions_custom() usage works", {
test_that("adjustment printing", {
expect_snapshot(tailor() %>% adjust_predictions_custom())
})

test_that("tunable", {
tlr <-
tailor() %>%
adjust_predictions_custom(linear_predictor = binomial()$linkfun(Class2))
adj_param <- tunable(tlr$adjustments[[1]])
expect_equal(adj_param, no_param)
})
8 changes: 8 additions & 0 deletions tests/testthat/test-adjust-probability-calibration.R
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,11 @@ test_that("errors informatively with bad input", {
expect_no_condition(adjust_numeric_calibration(tailor()))
expect_no_condition(adjust_numeric_calibration(tailor(), "linear"))
})

test_that("tunable", {
tlr <-
tailor() %>%
adjust_probability_calibration(method = "logistic")
adj_param <- tunable(tlr$adjustments[[1]])
expect_equal(adj_param, no_param)
})
15 changes: 15 additions & 0 deletions tests/testthat/test-adjust-probability-threshold.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,18 @@ test_that("adjustment printing", {
expect_snapshot(tailor() %>% adjust_probability_threshold())
expect_snapshot(tailor() %>% adjust_probability_threshold(hardhat::tune()))
})

test_that("tunable", {
tlr <-
tailor() %>%
adjust_probability_threshold(.1)
adj_param <- tunable(tlr$adjustments[[1]])
expect_equal(adj_param$name, "threshold")
expect_true(all(adj_param$source == "tailor"))
expect_true(is.list(adj_param$call_info))
expect_equal(nrow(adj_param), 1)
expect_equal(
names(adj_param),
c("name", "call_info", "source", "component", "component_id")
)
})
30 changes: 30 additions & 0 deletions tests/testthat/test-tailor.R
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,33 @@ test_that("error informatively with empty tidyselections", {
)
)
})

test_that("tunable (no adjustments)", {
tlr <-
tailor()

tlr_param <- tunable(tlr)
expect_equal(tlr_param, no_param)
})

test_that("tunable (multiple adjustments)", {
tlr <-
tailor() %>%
adjust_probability_threshold(.2) %>%
adjust_equivocal_zone()

tlr_param <- tunable(tlr)
expect_equal(tlr_param$name, c("threshold", "buffer"))
expect_true(all(tlr_param$source == "tailor"))
expect_true(is.list(tlr_param$call_info))
expect_equal(nrow(tlr_param), 2)
expect_equal(
names(tlr_param),
c("name", "call_info", "source", "component", "component_id")
)

expect_equal(
tlr_param,
bind_rows(tunable(tlr$adjustments[[1]]), tunable(tlr$adjustments[[2]]))
)
})

0 comments on commit 95e28be

Please sign in to comment.