Skip to content

Commit

Permalink
generate needed objects from split first-thing (#910)
Browse files Browse the repository at this point in the history
Generates the analysis, potato, and assessment sets first-thing in tune_grid_loop_iter(), as well as needed labels and prediction indices, and then no longer needs to reference split for the rest of the function body. As such, refers to those sets only by their names and removes references to split and split_orig in explanatory comments.
  • Loading branch information
simonpcouch authored May 31, 2024
1 parent b42ef28 commit 02e5632
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 30 deletions.
49 changes: 25 additions & 24 deletions R/grid_code_paths.R
Original file line number Diff line number Diff line change
Expand Up @@ -342,10 +342,6 @@ tune_grid_loop_iter <- function(split,
metrics_info = metrics_info(metrics),
params,
split_args = NULL) {
# `split` may be overwritten later on to create an "inner" split for
# post-processing. however, we want the original split to persist so we can
# use it (particularly `labels(split_orig)`) in logging
split_orig <- split
split_labels <- labels(split)

load_pkgs(workflow)
Expand Down Expand Up @@ -384,30 +380,43 @@ tune_grid_loop_iter <- function(split,
model_param_names <- model_params$id
preprocessor_param_names <- preprocessor_params$id

analysis <- rsample::analysis(split)
# inline rsample::assessment so that we can pass indices to `predict_model()`
assessment_rows <- as.integer(split, data = "assessment")
assessment <- vctrs::vec_slice(split$data, assessment_rows)

if (workflows::should_inner_split(workflow)) {
# if the workflow has a postprocessor that needs training (i.e. calibration),
# further split the analysis data into an "inner" analysis and
# assessment set.
# * the preprocessor and model (excluding the post-processor) are fitted
# on `analysis(split)`, the inner analysis set (just referred to as analysis)
# * that model generates predictions on `assessment(split)`, the
# potato set
# on `analysis(inner_split(split))`, the inner analysis set (just
# referred to as analysis)
# * that model generates predictions on `assessment(inner_split(split))`,
# the potato set
# * the post-processor is trained on the predictions generated from the
# potato set
# * the model (including the post-processor) generates predictions on the
# assessment set (not inner, i.e. `assessment(split_orig)`) and those
# predictions are assessed with performance metrics
# assessment set and those predictions are assessed with performance metrics
# todo: check if workflow's `method` is incompatible with `class(split)`?
# todo: workflow's `method` is currently ignored in favor of the one
# automatically dispatched to from `split`. consider this is combination
# with above todo.
split_args <- c(split_args, list(prop = workflow$post$actions$tailor$prop))
split <- rsample::inner_split(split, split_args = split_args)
analysis <- rsample::analysis(split)

# inline rsample::assessment so that we can pass indices to `predict_model()`
potato_rows <- as.integer(split, data = "assessment")
potato <- vctrs::vec_slice(split$data, potato_rows)
} else {
analysis <- rsample::analysis(split)

potato_rows <- NULL
potato <- NULL
}

rm(split)

# ----------------------------------------------------------------------------
# Preprocessor loop

Expand Down Expand Up @@ -511,7 +520,8 @@ tune_grid_loop_iter <- function(split,
iter_msg_predictions <- paste(iter_msg_model, "(predictions)")

iter_predictions <- .catch_and_log(
predict_model(split, workflow, iter_grid, metrics, iter_submodels,
predict_model(potato %||% assessment, potato_rows %||% assessment_rows,
workflow, iter_grid, metrics, iter_submodels,
metrics_info = metrics_info, eval_time = eval_time),
control,
split_labels,
Expand All @@ -528,18 +538,11 @@ tune_grid_loop_iter <- function(split,
if (workflows::should_inner_split(workflow)) {
# note that, since we're training a postprocessor, `iter_predictions`
# are the predictions from the potato set rather than the
# assessment set (i.e. `assessment(split_orig)`)
# assessment set

# train the post-processor on the predictions generated from the model
# on the potato set
# todo: this is the same assessment set that `predict_model` makes.
# we're ad-hoc `augment()`ing here, but would be nice to just have
# those predictors
# todo: needs a `.catch_and_log`
# todo: .fit_post currently takes in `assessment(split)` rather than
# a set of predictions, meaning that we predict on `assessment(split)`
# twice :(
potato <- rsample::assessment(split)
workflow_with_post <- .fit_post(workflow, potato)

workflow_with_post <- .fit_finalize(workflow_with_post)
Expand All @@ -556,13 +559,11 @@ tune_grid_loop_iter <- function(split,
elt_extract <- make_extracts(elt_extract, iter_grid, split_labels, .config = iter_config)
out_extracts <- append_extracts(out_extracts, elt_extract)


# generate predictions on the assessment set (not inner,
# i.e. `assessment(split_orig)`) from the model and apply the
# generate predictions on the assessment set from the model and apply the
# post-processor to those predictions to generate updated predictions
iter_predictions <- .catch_and_log(
predict_model(split_orig, workflow_with_post, iter_grid, metrics,
iter_submodels, metrics_info = metrics_info,
predict_model(assessment, assessment_rows, workflow_with_post, iter_grid,
metrics, iter_submodels, metrics_info = metrics_info,
eval_time = eval_time),
control,
split_labels,
Expand Down
8 changes: 2 additions & 6 deletions R/grid_helpers.R
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@

predict_model <- function(split, workflow, grid, metrics, submodels = NULL,
metrics_info, eval_time = NULL) {
predict_model <- function(new_data, orig_rows, workflow, grid, metrics,
submodels = NULL, metrics_info, eval_time = NULL) {

model <- extract_fit_parsnip(workflow)

new_data <- rsample::assessment(split)

forged <- forge_from_workflow(new_data, workflow)
x_vals <- forged$predictors
y_vals <- forged$outcomes
Expand All @@ -16,8 +14,6 @@ predict_model <- function(split, workflow, grid, metrics, submodels = NULL,
model$preproc$y_var <- names(y_vals)
}

orig_rows <- as.integer(split, data = "assessment")

if (length(orig_rows) != nrow(x_vals)) {
msg <- paste0(
"Some assessment set rows are not available at ",
Expand Down
2 changes: 2 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ new_bare_tibble <- function(x, ..., class = character()) {
.iter
}

`%||%` <- function (x, y) {if (rlang::is_null(x)) y else x}

## -----------------------------------------------------------------------------

#' Various accessor functions
Expand Down

0 comments on commit 02e5632

Please sign in to comment.