Skip to content

Commit

Permalink
fit calibrators at fit.container() (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch authored May 2, 2024
1 parent 71ed887 commit 071749e
Show file tree
Hide file tree
Showing 13 changed files with 222 additions and 96 deletions.
52 changes: 33 additions & 19 deletions R/adjust-numeric-calibration.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
#' Re-calibrate numeric predictions
#'
#' @param x A [container()].
#' @param calibrator A pre-trained calibration method from the \pkg{probably}
#' package, such as [probably::cal_estimate_linear()].
#' @param type Character. One of `"linear"`, `"isotonic"`, or
#' `"isotonic_boot"`, corresponding to the function from the \pkg{probably}
#' package [probably::cal_estimate_linear()],
#' [probably::cal_estimate_isotonic()], or
#' [probably::cal_estimate_isotonic_boot()], respectively.
#' @examples
#' library(modeldata)
#' library(probably)
Expand All @@ -14,27 +17,24 @@
#'
#' dat
#'
#' # calibrate numeric predictions
#' reg_cal <- cal_estimate_linear(dat, truth = y, estimate = y_pred)
#'
#' # specify calibration
#' reg_ctr <-
#' container(mode = "regression") %>%
#' adjust_numeric_calibration(reg_cal)
#' adjust_numeric_calibration(type = "linear")
#'
#' # "train" container
#' # train container
#' reg_ctr_trained <- fit(reg_ctr, dat, outcome = y, estimate = y_pred)
#'
#' predict(reg_ctr, dat)
#' predict(reg_ctr_trained, dat)
#' @export
adjust_numeric_calibration <- function(x, calibrator) {
check_container(x)
check_required(calibrator)
if (!inherits(calibrator, "cal_regression")) {
cli_abort(
"{.arg calibrator} should be a \\
{.help [<cal_regression> object](probably::cal_estimate_linear)}, \\
not {.obj_type_friendly {calibrator}}."
adjust_numeric_calibration <- function(x, type = NULL) {
# to-do: add argument specifying `prop` in initial_split
check_container(x, calibration_type = "numeric")
# wait to `check_type()` until `fit()` time
if (!is.null(type)) {
arg_match0(
type,
c("linear", "isotonic", "isotonic_boot")
)
}

Expand All @@ -43,7 +43,7 @@ adjust_numeric_calibration <- function(x, calibrator) {
"numeric_calibration",
inputs = "numeric",
outputs = "numeric",
arguments = list(calibrator = calibrator),
arguments = list(type = type),
results = list(),
trained = FALSE
)
Expand All @@ -67,19 +67,33 @@ print.numeric_calibration <- function(x, ...) {

#' @export
fit.numeric_calibration <- function(object, data, container = NULL, ...) {
  # Estimate a numeric calibrator on `data` and return a trained copy of the
  # operation with the fitted calibrator stored in `results$fit`.
  #
  # object:    the untrained numeric calibration operation.
  # data:      data passed to the estimation function as `.data`.
  # container: the enclosing container; supplies `type` as well as the
  #            `outcome` and `estimate` entries of `columns`.
  #
  # The requested calibration type is stored in the operation's `arguments`
  # list (`arguments = list(type = type)`), so it is read from there rather
  # than from a top-level `type` field. `check_type()` fills in a default when
  # no type was requested and validates it against the container type.
  type <- check_type(object$arguments$type, container$type, arg = "type")

  # todo: adjust_numeric_calibration() should take arguments to pass to
  # cal_estimate_* via dots
  # Build the call `probably::cal_estimate_<type>(...)` and evaluate it.
  fit <-
    eval_bare(
      call2(
        paste0("cal_estimate_", type),
        .data = data,
        truth = container$columns$outcome,
        estimate = container$columns$estimate,
        .ns = "probably"
      )
    )

  new_operation(
    class(object),
    inputs = object$inputs,
    outputs = object$outputs,
    arguments = object$arguments,
    results = list(fit = fit),
    trained = TRUE
  )
}

#' @export
predict.numeric_calibration <- function(object, new_data, container, ...) {
  # Apply the fitted calibrator held in `results$fit` to `new_data`.
  # The calibrator is no longer supplied by the user as an argument, so the
  # operation must have been fitted before predictions can be calibrated.
  probably::cal_apply(new_data, object$results$fit)
}

# todo probably needs required_pkgs methods for cal objects
Expand Down
45 changes: 31 additions & 14 deletions R/adjust-probability-calibration.R
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
#' Re-calibrate classification probability predictions
#'
#' @param x A [container()].
#' @param calibrator A pre-trained calibration method from the \pkg{probably}
#' package, such as [probably::cal_estimate_logistic()].
#' @param type Character. One of `"logistic"`, `"multinomial"`,
#' `"beta"`, `"isotonic"`, or `"isotonic_boot"`, corresponding to the
#' function from the \pkg{probably} package [probably::cal_estimate_logistic()],
#' [probably::cal_estimate_multinomial()], etc., respectively.
#' @export
adjust_probability_calibration <- function(x, calibrator) {
check_container(x)
cls <- c("cal_binary", "cal_multinomial")
check_required(calibrator)
if (!inherits_any(calibrator, cls)) {
cli_abort(
"{.arg calibrator} should be a \\
{.help [<cal_binary> or <cal_multinomial> object](probably::cal_estimate_logistic)}, \\
not {.obj_type_friendly {calibrator}}."
adjust_probability_calibration <- function(x, type = NULL) {
# to-do: add argument specifying `prop` in initial_split
check_container(x, calibration_type = "probability")
# wait to `check_type()` until `fit()` time
if (!is.null(type)) {
arg_match(
type,
c("logistic", "multinomial", "beta", "isotonic", "isotonic_boot")
)
}

Expand All @@ -21,7 +22,7 @@ adjust_probability_calibration <- function(x, calibrator) {
"probability_calibration",
inputs = "probability",
outputs = "probability_class",
arguments = list(calibrator = calibrator),
arguments = list(type = type),
results = list(),
trained = FALSE
)
Expand All @@ -45,19 +46,35 @@ print.probability_calibration <- function(x, ...) {

#' @export
fit.probability_calibration <- function(object, data, container = NULL, ...) {
  # Estimate a probability calibrator on `data` and return a trained copy of
  # the operation with the fitted calibrator stored in `results$fit`.
  #
  # object:    the untrained probability calibration operation.
  # data:      data passed to the estimation function as `.data`.
  # container: the enclosing container; supplies `type` as well as the
  #            `outcome` and `estimate` entries of `columns`.
  #
  # The requested calibration type is stored in the operation's `arguments`
  # list (`arguments = list(type = type)`), so it is read from there rather
  # than from a top-level `type` field. `check_type()` fills in a default when
  # no type was requested and validates it against the container type.
  type <- check_type(object$arguments$type, container$type, arg = "type")

  # todo: adjust_probability_calibration() should take arguments to pass to
  # cal_estimate_* via dots
  # to-do: add argument specifying `prop` in initial_split
  # Build the call `probably::cal_estimate_<type>(...)` and evaluate it.
  fit <-
    eval_bare(
      call2(
        paste0("cal_estimate_", type),
        .data = data,
        # todo: make getters for the entries in `columns`
        truth = container$columns$outcome,
        estimate = container$columns$estimate,
        .ns = "probably"
      )
    )

  new_operation(
    class(object),
    inputs = object$inputs,
    outputs = object$outputs,
    arguments = object$arguments,
    results = list(fit = fit),
    trained = TRUE
  )
}

#' @export
predict.probability_calibration <- function(object, new_data, container, ...) {
  # Apply the fitted calibrator held in `results$fit` to `new_data`.
  # The calibrator is no longer supplied by the user as an argument, so the
  # operation must have been fitted before predictions can be calibrated.
  probably::cal_apply(new_data, object$results$fit)
}

# todo probably needs required_pkgs methods for cal objects
Expand Down
2 changes: 1 addition & 1 deletion R/container.R
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ fit.container <- function(object, .data, outcome, estimate, probabilities = c(),

num_oper <- length(object$operations)
for (op in seq_len(num_oper)) {
object$operations[[op]] <- fit(object$operations[[op]], data, object)
object$operations[[op]] <- fit(object$operations[[op]], .data, object)
.data <- predict(object$operations[[op]], .data, object)
}

Expand Down
85 changes: 83 additions & 2 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,95 @@ is_container <- function(x) {
}

# ad-hoc checking --------------------------------------------------------------
check_container <- function(x,
                            calibration_type = NULL,
                            call = caller_env(),
                            arg = caller_arg(x)) {
  # Validate that `x` is a container and, when `calibration_type` is given,
  # that the calibration is compatible with the container's type.
  #
  # x:                the object to check.
  # calibration_type: NULL (skip the compatibility check), or the kind of
  #                   calibration being added: "numeric" or "probability".
  # call, arg:        passed through for error reporting.
  #
  # Aborts on failure; otherwise returns NULL invisibly.
  if (!is_container(x)) {
    cli_abort(
      "{.arg {arg}} should be a {.help [{.cls container}](container::container)}, \\
not {.obj_type_friendly {x}}.",
      call = call
    )
  }

  # check that the type of calibration ("numeric" or "probability") is
  # compatible with the container type
  if (!is.null(calibration_type)) {
    container_type <- x$type
    # Container types not listed here are not checked. Both "multiclass" (the
    # name used by `check_type()`) and "multinomial" are treated as
    # classification types that require probability calibration.
    switch(
      container_type,
      regression = check_calibration_type(
        calibration_type,
        "numeric",
        container_type,
        call = call
      ),
      binary = ,
      multiclass = ,
      multinomial = check_calibration_type(
        calibration_type,
        "probability",
        container_type,
        call = call
      )
    )
  }

  invisible()
}

check_calibration_type <- function(calibration_type, calibration_type_expected,
                                   container_type, call) {
  # Guard called from check_container(): abort when the calibration kind being
  # added differs from the kind the container type expects.
  #
  # calibration_type:          the kind of calibration being added.
  # calibration_type_expected: the kind the container type requires.
  # container_type:            the container's type, used in the message.
  # call:                      the call to report the error against.
  #
  # Returns NULL invisibly when the two kinds are identical.
  kinds_agree <- identical(calibration_type, calibration_type_expected)
  if (kinds_agree) {
    return(invisible(NULL))
  }

  cli_abort(
    "A {.field {container_type}} container is incompatible with the operation \\
{.fun {paste0('adjust_', calibration_type, '_calibration')}}.",
    call = call
  )
}

# calibration types accepted for each container type
types_regression <- c("linear", "isotonic", "isotonic_boot")
types_binary <- c("logistic", "beta", "isotonic", "isotonic_boot")
types_multiclass <- c("multinomial", "beta", "isotonic", "isotonic_boot")

# a check function to be called when a container is being `fit()`ted.
# by the time a container is fitted, we have:
# * `adjust_type`, the `type` argument passed to an `adjust_*` function
#     * this argument has already been checked to agree with the kind of
#       `adjust_*()` function via `arg_match0()`.
# * `container_type`, the `type` argument either specified in `container()`
#   or inferred in `fit.container()`.
#
# Returns the calibration type to use: a default chosen from `container_type`
# when `adjust_type` is NULL, otherwise `adjust_type` itself once it has been
# validated against the types allowed for `container_type`.
#
# `check_container()` spells the multi-class container type "multinomial"
# while this function historically used "multiclass"; both spellings are
# accepted here so the two checks agree.
check_type <- function(adjust_type,
                       container_type,
                       arg = caller_arg(adjust_type),
                       call = caller_env()) {
  # if no `adjust_type` was supplied, infer a reasonable one based on the
  # `container_type`
  if (is.null(adjust_type)) {
    default_type <- switch(
      container_type,
      regression = "linear",
      binary = "logistic",
      multiclass = ,
      multinomial = "multinomial"
    )
    if (!is.null(default_type)) {
      return(default_type)
    }
    # No default exists for this container type: fall through so that
    # `arg_match0()` below reports the missing type.
  }

  # For a container type not listed here, accept any known calibration type.
  allowed_types <- switch(
    container_type,
    regression = types_regression,
    binary = types_binary,
    multiclass = ,
    multinomial = types_multiclass,
    unique(c(types_regression, types_binary, types_multiclass))
  )

  arg_match0(adjust_type, allowed_types, arg_nm = arg, error_call = call)

  adjust_type
}

18 changes: 9 additions & 9 deletions man/adjust_numeric_calibration.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 5 additions & 3 deletions man/adjust_probability_calibration.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 9 additions & 8 deletions tests/testthat/_snaps/adjust-numeric-calibration.md
Original file line number Diff line number Diff line change
@@ -1,35 +1,36 @@
# adjustment printing

Code
ctr_reg %>% adjust_numeric_calibration(dummy_reg_cal)
ctr_reg %>% adjust_numeric_calibration()
Message
-- Container -------------------------------------------------------------------
A postprocessor with 1 operation:
A regression postprocessor with 1 operation:
* Re-calibrate numeric predictions.

# errors informatively with bad input

Code
adjust_numeric_calibration(ctr_reg)
adjust_numeric_calibration(ctr_reg, "boop")
Condition
Error in `adjust_numeric_calibration()`:
! `calibrator` is absent but must be supplied.
! `type` must be one of "linear", "isotonic", or "isotonic_boot", not "boop".

---

Code
adjust_numeric_calibration(ctr_reg, "boop")
container("classification", "binary") %>% adjust_numeric_calibration("linear")
Condition
Error in `adjust_numeric_calibration()`:
! `calibrator` should be a <cal_regression> object (`?probably::cal_estimate_linear()`), not a string.
! A binary container is incompatible with the operation `adjust_numeric_calibration()`.

---

Code
adjust_numeric_calibration(ctr_cls, dummy_cls_cal)
container("regression", "regression") %>% adjust_numeric_calibration("binary")
Condition
Error in `adjust_numeric_calibration()`:
! `calibrator` should be a <cal_regression> object (`?probably::cal_estimate_linear()`), not a <cal_binary> object.
! `type` must be one of "linear", "isotonic", or "isotonic_boot", not "binary".
i Did you mean "linear"?

Loading

0 comments on commit 071749e

Please sign in to comment.