From 233e12e6440ebec93125e948d50e2f53a9e6274c Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Fri, 20 Sep 2024 17:03:14 -0700 Subject: [PATCH 01/13] remove extraneous code in favour of NextMethod --- R/epi_recipe.R | 147 ++++++++--------------------------------------- R/epi_workflow.R | 21 +++---- 2 files changed, 32 insertions(+), 136 deletions(-) diff --git a/R/epi_recipe.R b/R/epi_recipe.R index c3a18d3cb..7b3fc8fdb 100644 --- a/R/epi_recipe.R +++ b/R/epi_recipe.R @@ -95,23 +95,10 @@ add_epi_recipe <- function( #' @rdname add_epi_recipe #' @export remove_epi_recipe <- function(x) { - workflows:::validate_is_workflow(x) - - if (!workflows:::has_preprocessor_recipe(x)) { - rlang::warn("The workflow has no recipe preprocessor to remove.") - } - - actions <- x$pre$actions - actions[["recipe"]] <- NULL - - new_epi_workflow( - pre = workflows:::new_stage_pre(actions = actions), - fit = x$fit, - post = x$post, - trained = FALSE - ) + workflows::remove_recipe(x) } + #' @rdname add_epi_recipe #' @export update_epi_recipe <- function(x, recipe, ..., blueprint = default_epi_recipe_blueprint()) { @@ -180,15 +167,21 @@ adjust_epi_recipe <- function(x, which_step, ..., blueprint = default_epi_recipe #' @rdname adjust_epi_recipe #' @export -adjust_epi_recipe.epi_workflow <- function(x, which_step, ..., blueprint = default_epi_recipe_blueprint()) { - recipe <- adjust_epi_recipe(workflows::extract_preprocessor(x), which_step, ...) +adjust_epi_recipe.epi_workflow <- function( + x, which_step, ..., blueprint = default_epi_recipe_blueprint() +) { - update_epi_recipe(x, recipe, blueprint = blueprint) + rec <- adjust_epi_recipe( + workflows::extract_preprocessor(x), which_step, ... + ) + update_epi_recipe(x, rec, blueprint = blueprint) } #' @rdname adjust_epi_recipe #' @export -adjust_epi_recipe.epi_recipe <- function(x, which_step, ..., blueprint = default_epi_recipe_blueprint()) { +adjust_epi_recipe.epi_recipe <- function( + x, which_step, ..., blueprint = default_epi_recipe_blueprint() +) { if (!(is.numeric(which_step) || is.character(which_step))) { cli::cli_abort( c("`which_step` must be a number or a character.", @@ -294,109 +287,17 @@ kill_levels <- function(x, keys) { #' @export print.epi_recipe <- function(x, form_width = 30, ...) { - cli::cli_div(theme = list(.pkg = list("vec-trunc" = Inf, "vec-last" = ", "))) - - cli::cli_h1("Epi Recipe") - cli::cli_h3("Inputs") - - tab <- table(x$var_info$role, useNA = "ifany") - tab <- stats::setNames(tab, names(tab)) - names(tab)[is.na(names(tab))] <- "undeclared role" - - roles <- c("outcome", "predictor", "case_weights", "undeclared role") - - tab <- c( - tab[names(tab) == roles[1]], - tab[names(tab) == roles[2]], - tab[names(tab) == roles[3]], - sort(tab[!names(tab) %in% roles], TRUE), - tab[names(tab) == roles[4]] - ) - - cli::cli_text("Number of variables by role") - - spaces_needed <- max(nchar(names(tab))) - nchar(names(tab)) + - max(nchar(tab)) - nchar(tab) - - cli::cli_verbatim( - glue::glue("{names(tab)}: {strrep('\ua0', spaces_needed)}{tab}") - ) - - if ("tr_info" %in% names(x)) { - cli::cli_h3("Training information") - nmiss <- x$tr_info$nrows - x$tr_info$ncomplete - nrows <- x$tr_info$nrows - - cli::cli_text( - "Training data contained {nrows} data points and {cli::no(nmiss)} \\ - incomplete row{?s}." - ) - } - - if (!is.null(x$steps)) { - cli::cli_h3("Operations") - } - - fmt <- cli::cli_fmt({ - for (step in x$steps) { - print(step, form_width = form_width) - } - }) - cli::cli_ol(fmt) - cli::cli_end() - - invisible(x) -} - -# Currently only used in the workflow printing -print_preprocessor_recipe <- function(x, ...) { - recipe <- workflows::extract_preprocessor(x) - steps <- recipe$steps - n_steps <- length(steps) - cli::cli_text("{n_steps} Recipe step{?s}.") - - if (n_steps == 0L) { - return(invisible(x)) - } - - step_names <- map_chr(steps, workflows:::pull_step_name) - - if (n_steps <= 10L) { - cli::cli_ol(step_names) - return(invisible(x)) - } - - extra_steps <- n_steps - 10L - step_names <- step_names[1:10] - - cli::cli_ol(step_names) - cli::cli_bullets("... and {extra_steps} more step{?s}.") - invisible(x) -} - -print_preprocessor <- function(x) { - has_preprocessor_formula <- workflows:::has_preprocessor_formula(x) - has_preprocessor_recipe <- workflows:::has_preprocessor_recipe(x) - has_preprocessor_variables <- workflows:::has_preprocessor_variables(x) - - no_preprocessor <- !has_preprocessor_formula && !has_preprocessor_recipe && - !has_preprocessor_variables - - if (no_preprocessor) { - return(invisible(x)) - } - - cli::cli_rule("Preprocessor") - cli::cli_text("") - - if (has_preprocessor_formula) { - workflows:::print_preprocessor_formula(x) - } - if (has_preprocessor_recipe) { - print_preprocessor_recipe(x) - } - if (has_preprocessor_variables) { - workflows:::print_preprocessor_variables(x) - } + o <- cli::cli_fmt(NextMethod()) + # Fix up the recipe name + rr <- unlist(strsplit(o[2], "Recipe")) + len <- nchar(rr[2]) + h1_tail <- paste0(substr(rr[2], 1, len / 2 - 10), substr(rr[2], len / 2, len)) + o[2] <- paste0(rr[1], "Epi Recipe", h1_tail) + + # Number the operations + ops <- seq(grep(" Operations ", o, fixed = TRUE) + 1, length(o)) + rep_ops <- sub("\033[36m•\033[39m ", "", o[ops], fixed = TRUE) # kills the • + o[ops] <- paste0(ops - ops[1] + 1, ". ", rep_ops) + cli::cli_bullets(o) invisible(x) } diff --git a/R/epi_workflow.R b/R/epi_workflow.R index 369b96eb1..8ce86a9f6 100644 --- a/R/epi_workflow.R +++ b/R/epi_workflow.R @@ -32,18 +32,13 @@ #' #' wf epi_workflow <- function(preprocessor = NULL, spec = NULL, postprocessor = NULL) { - out <- workflows::workflow(spec = spec) - class(out) <- c("epi_workflow", class(out)) + out <- workflows::workflow(preprocessor = preprocessor, spec = spec) - if (is_epi_recipe(preprocessor)) { - out <- add_epi_recipe(out, preprocessor) - } else if (!is_null(preprocessor)) { - out <- workflows:::add_preprocessor(out, preprocessor) - } if (!is_null(postprocessor)) { out <- add_postprocessor(out, postprocessor) } + class(out) <- c("epi_workflow", class(out)) out } @@ -162,11 +157,14 @@ predict.epi_workflow <- function(object, new_data, type = NULL, opts = list(), . } components <- list() components$mold <- workflows::extract_mold(object) - components$forged <- hardhat::forge(new_data, + components$forged <- hardhat::forge( + new_data, blueprint = components$mold$blueprint ) components$keys <- grab_forged_keys(components$forged, object, new_data) - components <- apply_frosting(object, components, new_data, type = type, opts = opts, ...) + components <- apply_frosting( + object, components, new_data, type = type, opts = opts, ... + ) components$predictions } @@ -216,10 +214,7 @@ new_epi_workflow <- function( #' @export print.epi_workflow <- function(x, ...) { - print_header(x) - print_preprocessor(x) - # workflows:::print_case_weights(x) - print_model(x) + NextMethod() print_postprocessor(x) invisible(x) } From 79ec6161eb044e0109b5a5fc73110b6eb7f38bfb Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Sat, 21 Sep 2024 10:34:06 -0700 Subject: [PATCH 02/13] everything is broken --- R/epi_recipe.R | 5 ++++- R/epi_workflow.R | 1 - tests/testthat/test-epi_workflow.R | 1 + 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/R/epi_recipe.R b/R/epi_recipe.R index 7b3fc8fdb..b7534da05 100644 --- a/R/epi_recipe.R +++ b/R/epi_recipe.R @@ -95,7 +95,9 @@ add_epi_recipe <- function( #' @rdname add_epi_recipe #' @export remove_epi_recipe <- function(x) { - workflows::remove_recipe(x) + wf <- workflows::remove_recipe(x) + class(wf) <- c("epi_workflow", class(wf)) + wf } @@ -222,6 +224,7 @@ prep.epi_recipe <- function( if (!strings_as_factors) { return(NextMethod("prep")) } + browser() # workaround to avoid converting strings2factors with recipes::prep.recipe() # We do the conversion here, then set it to FALSE training <- recipes:::check_training_set(training, x, fresh) diff --git a/R/epi_workflow.R b/R/epi_workflow.R index 8ce86a9f6..ae91803f4 100644 --- a/R/epi_workflow.R +++ b/R/epi_workflow.R @@ -96,7 +96,6 @@ fit.epi_workflow <- function(object, data, ..., control = workflows::control_wor as_of = attributes(data)$metadata$as_of ) object$original_data <- data - NextMethod() } diff --git a/tests/testthat/test-epi_workflow.R b/tests/testthat/test-epi_workflow.R index 94799faa1..d276953e5 100644 --- a/tests/testthat/test-epi_workflow.R +++ b/tests/testthat/test-epi_workflow.R @@ -59,6 +59,7 @@ test_that("model can be added/updated/removed from epi_workflow", { expect_equal(class(model_spec2), c("linear_reg", "model_spec")) wf <- remove_model(wf) + expect_equal(class(wf), c("epi_workflow", "workflow")) expect_error(extract_spec_parsnip(wf)) expect_equal(wf$fit$actions$model$spec, NULL) }) From 5d8fb879fb32434f845334842e014b3c6d6021e8 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Sat, 21 Sep 2024 10:42:51 -0700 Subject: [PATCH 03/13] simplify recipe printing --- R/epi_recipe.R | 115 ++++++------------------------------------------- 1 file changed, 12 insertions(+), 103 deletions(-) diff --git a/R/epi_recipe.R b/R/epi_recipe.R index c3a18d3cb..3a366391f 100644 --- a/R/epi_recipe.R +++ b/R/epi_recipe.R @@ -294,109 +294,18 @@ kill_levels <- function(x, keys) { #' @export print.epi_recipe <- function(x, form_width = 30, ...) { - cli::cli_div(theme = list(.pkg = list("vec-trunc" = Inf, "vec-last" = ", "))) - - cli::cli_h1("Epi Recipe") - cli::cli_h3("Inputs") - - tab <- table(x$var_info$role, useNA = "ifany") - tab <- stats::setNames(tab, names(tab)) - names(tab)[is.na(names(tab))] <- "undeclared role" - - roles <- c("outcome", "predictor", "case_weights", "undeclared role") - - tab <- c( - tab[names(tab) == roles[1]], - tab[names(tab) == roles[2]], - tab[names(tab) == roles[3]], - sort(tab[!names(tab) %in% roles], TRUE), - tab[names(tab) == roles[4]] - ) - - cli::cli_text("Number of variables by role") - - spaces_needed <- max(nchar(names(tab))) - nchar(names(tab)) + - max(nchar(tab)) - nchar(tab) - - cli::cli_verbatim( - glue::glue("{names(tab)}: {strrep('\ua0', spaces_needed)}{tab}") - ) - - if ("tr_info" %in% names(x)) { - cli::cli_h3("Training information") - nmiss <- x$tr_info$nrows - x$tr_info$ncomplete - nrows <- x$tr_info$nrows - - cli::cli_text( - "Training data contained {nrows} data points and {cli::no(nmiss)} \\ - incomplete row{?s}." - ) - } - - if (!is.null(x$steps)) { - cli::cli_h3("Operations") - } - - fmt <- cli::cli_fmt({ - for (step in x$steps) { - print(step, form_width = form_width) - } - }) - cli::cli_ol(fmt) - cli::cli_end() - - invisible(x) -} - -# Currently only used in the workflow printing -print_preprocessor_recipe <- function(x, ...) { - recipe <- workflows::extract_preprocessor(x) - steps <- recipe$steps - n_steps <- length(steps) - cli::cli_text("{n_steps} Recipe step{?s}.") - - if (n_steps == 0L) { - return(invisible(x)) - } - - step_names <- map_chr(steps, workflows:::pull_step_name) - - if (n_steps <= 10L) { - cli::cli_ol(step_names) - return(invisible(x)) - } - - extra_steps <- n_steps - 10L - step_names <- step_names[1:10] - - cli::cli_ol(step_names) - cli::cli_bullets("... and {extra_steps} more step{?s}.") + o <- cli::cli_fmt(NextMethod()) + # Fix up the recipe name + rr <- unlist(strsplit(o[2], "Recipe")) + len <- nchar(rr[2]) + h1_tail <- paste0(substr(rr[2], 1, len / 2 - 10), substr(rr[2], len / 2, len)) + o[2] <- paste0(rr[1], "Epi Recipe", h1_tail) + + # Number the operations + ops <- seq(grep(" Operations ", o, fixed = TRUE) + 1, length(o)) + rep_ops <- sub("\033[36m•\033[39m ", "", o[ops], fixed = TRUE) # kills the • + o[ops] <- paste0(ops - ops[1] + 1, ". ", rep_ops) + cli::cli_bullets(o) invisible(x) } -print_preprocessor <- function(x) { - has_preprocessor_formula <- workflows:::has_preprocessor_formula(x) - has_preprocessor_recipe <- workflows:::has_preprocessor_recipe(x) - has_preprocessor_variables <- workflows:::has_preprocessor_variables(x) - - no_preprocessor <- !has_preprocessor_formula && !has_preprocessor_recipe && - !has_preprocessor_variables - - if (no_preprocessor) { - return(invisible(x)) - } - - cli::cli_rule("Preprocessor") - cli::cli_text("") - - if (has_preprocessor_formula) { - workflows:::print_preprocessor_formula(x) - } - if (has_preprocessor_recipe) { - print_preprocessor_recipe(x) - } - if (has_preprocessor_variables) { - workflows:::print_preprocessor_variables(x) - } - invisible(x) -} From 1278cee5a3af0f656b70c3a023f7c1c405ed5aa3 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Sun, 22 Sep 2024 12:14:28 -0700 Subject: [PATCH 04/13] fix print.epi_recipe to pass tests --- R/epi_recipe.R | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/R/epi_recipe.R b/R/epi_recipe.R index 3a366391f..1a6a196cd 100644 --- a/R/epi_recipe.R +++ b/R/epi_recipe.R @@ -302,9 +302,13 @@ print.epi_recipe <- function(x, form_width = 30, ...) { o[2] <- paste0(rr[1], "Epi Recipe", h1_tail) # Number the operations - ops <- seq(grep(" Operations ", o, fixed = TRUE) + 1, length(o)) - rep_ops <- sub("\033[36m•\033[39m ", "", o[ops], fixed = TRUE) # kills the • - o[ops] <- paste0(ops - ops[1] + 1, ". ", rep_ops) + has_operations <- any(grepl(" Operations ", o, fixed = TRUE)) + if (has_operations) { + ops <- seq(grep(" Operations ", o, fixed = TRUE) + 1, length(o)) + # kills the \bullet + rep_ops <- sub("^\\033\\[36m.\\033\\[39m ", "", o[ops], perl = TRUE) + o[ops] <- paste0(ops - ops[1] + 1, ". ", rep_ops) + } cli::cli_bullets(o) invisible(x) } From fdb4b4dc5e435c09087f4a597838fe82f5294285 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Sun, 22 Sep 2024 12:20:35 -0700 Subject: [PATCH 05/13] remove browser() --- R/epi_recipe.R | 1 - 1 file changed, 1 deletion(-) diff --git a/R/epi_recipe.R b/R/epi_recipe.R index 4f9c84bbb..efa16b66a 100644 --- a/R/epi_recipe.R +++ b/R/epi_recipe.R @@ -224,7 +224,6 @@ prep.epi_recipe <- function( if (!strings_as_factors) { return(NextMethod("prep")) } - browser() # workaround to avoid converting strings2factors with recipes::prep.recipe() # We do the conversion here, then set it to FALSE training <- recipes:::check_training_set(training, x, fresh) From 604199670070ae6f71b3cfe67dbaa56a3513a7fd Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Sun, 22 Sep 2024 12:25:00 -0700 Subject: [PATCH 06/13] white space changes --- R/epi_workflow.R | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/R/epi_workflow.R b/R/epi_workflow.R index 369b96eb1..fd3cb82ee 100644 --- a/R/epi_workflow.R +++ b/R/epi_workflow.R @@ -162,11 +162,14 @@ predict.epi_workflow <- function(object, new_data, type = NULL, opts = list(), . } components <- list() components$mold <- workflows::extract_mold(object) - components$forged <- hardhat::forge(new_data, + components$forged <- hardhat::forge( + new_data, blueprint = components$mold$blueprint ) components$keys <- grab_forged_keys(components$forged, object, new_data) - components <- apply_frosting(object, components, new_data, type = type, opts = opts, ...) + components <- apply_frosting( + object, components, new_data, type = type, opts = opts, ... + ) components$predictions } From f20ef9bdc26dde48ad5a43e903a3691b14736791 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Sun, 22 Sep 2024 12:37:36 -0700 Subject: [PATCH 07/13] epi_workflow constructor works --- R/epi_workflow.R | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/R/epi_workflow.R b/R/epi_workflow.R index fd3cb82ee..56f2fde5a 100644 --- a/R/epi_workflow.R +++ b/R/epi_workflow.R @@ -32,18 +32,15 @@ #' #' wf epi_workflow <- function(preprocessor = NULL, spec = NULL, postprocessor = NULL) { - out <- workflows::workflow(spec = spec) - class(out) <- c("epi_workflow", class(out)) - + out <- workflows::workflow(preprocessor, spec = spec) if (is_epi_recipe(preprocessor)) { + out <- workflows::remove_recipe(out) out <- add_epi_recipe(out, preprocessor) - } else if (!is_null(preprocessor)) { - out <- workflows:::add_preprocessor(out, preprocessor) } + class(out) <- c("epi_workflow", class(out)) if (!is_null(postprocessor)) { out <- add_postprocessor(out, postprocessor) } - out } From 0cc285bd28f2d73eb7869ee011fa28ce2d5579ae Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Sun, 22 Sep 2024 13:44:51 -0700 Subject: [PATCH 08/13] everything passes --- R/epi_recipe.R | 18 +++--------------- R/epi_workflow.R | 12 ------------ R/model-methods.R | 15 +++------------ 3 files changed, 6 insertions(+), 39 deletions(-) diff --git a/R/epi_recipe.R b/R/epi_recipe.R index 1a6a196cd..068b64eb3 100644 --- a/R/epi_recipe.R +++ b/R/epi_recipe.R @@ -95,21 +95,9 @@ add_epi_recipe <- function( #' @rdname add_epi_recipe #' @export remove_epi_recipe <- function(x) { - workflows:::validate_is_workflow(x) - - if (!workflows:::has_preprocessor_recipe(x)) { - rlang::warn("The workflow has no recipe preprocessor to remove.") - } - - actions <- x$pre$actions - actions[["recipe"]] <- NULL - - new_epi_workflow( - pre = workflows:::new_stage_pre(actions = actions), - fit = x$fit, - post = x$post, - trained = FALSE - ) + x <- workflows::remove_recipe(x) + class(x) <- c("epi_workflow", class(x)) + x } #' @rdname add_epi_recipe diff --git a/R/epi_workflow.R b/R/epi_workflow.R index 56f2fde5a..326cb641c 100644 --- a/R/epi_workflow.R +++ b/R/epi_workflow.R @@ -201,18 +201,6 @@ augment.epi_workflow <- function(x, new_data, ...) { full_join(predictions, new_data, by = join_by) } -new_epi_workflow <- function( - pre = workflows:::new_stage_pre(), - fit = workflows:::new_stage_fit(), - post = workflows:::new_stage_post(), - trained = FALSE) { - out <- workflows:::new_workflow( - pre = pre, fit = fit, post = post, trained = trained - ) - class(out) <- c("epi_workflow", class(out)) - out -} - #' @export print.epi_workflow <- function(x, ...) { diff --git a/R/model-methods.R b/R/model-methods.R index 131a6ee91..26e410237 100644 --- a/R/model-methods.R +++ b/R/model-methods.R @@ -80,18 +80,9 @@ Add_model.epi_workflow <- function(x, spec, ..., formula = NULL) { #' @rdname Add_model #' @export Remove_model.epi_workflow <- function(x) { - workflows:::validate_is_workflow(x) - - if (!workflows:::has_spec(x)) { - rlang::warn("The workflow has no model to remove.") - } - - new_epi_workflow( - pre = x$pre, - fit = workflows:::new_stage_fit(), - post = x$post, - trained = FALSE - ) + x <- workflows::remove_model(x) + class(x) <- c("epi_workflow", class(x)) + x } #' @rdname Add_model From c4b7bcdfdf3ff9970e4d99b7638326382ff1228f Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Mon, 23 Sep 2024 08:43:01 -0700 Subject: [PATCH 09/13] frosting refactor complete --- NAMESPACE | 7 +++ R/frosting.R | 100 +++++++++++++++++++++------------ R/utils-misc.R | 2 + man/add_frosting.Rd | 9 +-- tests/testthat/test-frosting.R | 12 ++-- 5 files changed, 85 insertions(+), 45 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index c20b8c801..94d9981df 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -8,6 +8,8 @@ S3method(Remove_model,epi_workflow) S3method(Remove_model,workflow) S3method(Update_model,epi_workflow) S3method(Update_model,workflow) +S3method(add_frosting,default) +S3method(add_frosting,epi_workflow) S3method(adjust_epi_recipe,epi_recipe) S3method(adjust_epi_recipe,epi_workflow) S3method(adjust_frosting,epi_workflow) @@ -97,6 +99,8 @@ S3method(quantile,dist_quantiles) S3method(recipe,epi_df) S3method(recipes::recipe,formula) S3method(refresh_blueprint,default_epi_recipe_blueprint) +S3method(remove_frosting,default) +S3method(remove_frosting,epi_workflow) S3method(residuals,flatline) S3method(run_mold,default_epi_recipe_blueprint) S3method(slather,layer_add_forecast_date) @@ -119,6 +123,8 @@ S3method(tidy,check_enough_train_data) S3method(tidy,frosting) S3method(tidy,layer) S3method(update,layer) +S3method(update_frosting,default) +S3method(update_frosting,epi_workflow) S3method(vec_ptype_abbr,dist_quantiles) S3method(vec_ptype_full,dist_quantiles) S3method(weighted_interval_score,default) @@ -271,6 +277,7 @@ importFrom(rlang,":=") importFrom(rlang,abort) importFrom(rlang,arg_match) importFrom(rlang,as_function) +importFrom(rlang,caller_arg) importFrom(rlang,caller_env) importFrom(rlang,enquo) importFrom(rlang,enquos) diff --git a/R/frosting.R b/R/frosting.R index 6d1e9196c..f6f9fb0c5 100644 --- a/R/frosting.R +++ b/R/frosting.R @@ -1,6 +1,6 @@ -#' Add frosting to a workflow +#' Add frosting to an epi_workflow #' -#' @param x A workflow +#' @param x An epi_workflow #' @param frosting A frosting object created using `frosting()`. #' @param ... Not used. #' @@ -38,37 +38,68 @@ #' p3 <- predict(wf3, latest) #' p3 #' +#' add_frosting <- function(x, frosting, ...) { - rlang::check_dots_empty() - action <- workflows:::new_action_post(frosting = frosting) - epi_add_action(x, action, "frosting", ...) + UseMethod("add_frosting") } +#' @export +add_frosting.default <- function(x, frosting, ..., arg = caller_arg(x)) { + cli_abort("{x} must be a {.cls workflow}, not a {.cls {class(x)[1]}}.") +} -# Hacks around workflows `order_stage_post <- charcter(0)` ---------------- -epi_add_action <- function(x, action, name, ..., call = caller_env()) { - workflows:::validate_is_workflow(x, call = call) - add_action_frosting(x, action, name, ..., call = call) +#' @export +add_frosting.epi_workflow <- function(x, frosting, ...) { + rlang::check_dots_empty() + action <- structure( + list(frosting = frosting), + class = c("action_post", "action") + ) + if ("frosting" %in% names(x$post$actions)) { + cli_abort("A `frosting` action has already been added to this workflow.") + } + add_frosting_postprocessor(x, action) } -add_action_frosting <- function(x, action, name, ..., call = caller_env()) { - workflows:::check_singleton(x$post$actions, name, call = call) - x$post <- workflows:::add_action_to_stage(x$post, action, name, order_stage_frosting()) - x + +add_frosting_postprocessor <- function(wf, action) { + actions <- c(wf$post$actions, list(frosting = action)) + order <- intersect("frosting", names(actions)) + actions <- actions[order] + wf$post$actions <- actions + wf } -order_stage_frosting <- function() "frosting" + +# Hacks around workflows `order_stage_post <- charcter(0)` ---------------- +# epi_add_action <- function(x, action, name, ..., call = caller_env()) { +# workflows:::validate_is_workflow(x, call = call) +# add_action_frosting(x, action, name, ..., call = call) +# } +# add_action_frosting <- function(x, action, name, ..., call = caller_env()) { +# workflows:::check_singleton(x$post$actions, name, call = call) +# x$post <- workflows:::add_action_to_stage(x$post, action, name, order_stage_frosting()) +# x +# } +# order_stage_frosting <- function() "frosting" # End hacks. See cmu-delphi/epipredict#75 #' @rdname add_frosting #' @export -remove_frosting <- function(x) { - workflows:::validate_is_workflow(x) +remove_frosting <- function(x, ...) { + UseMethod("remove_frosting") +} + +#' @export +remove_frosting.default <- function(x, ..., arg = caller_arg(x)) { + cli_abort("{arg} must be an {.cls epi_workflow}, not a {.cls {class(x)[1]}}.") +} +#' @export +remove_frosting.epi_workflow <- function(x, ..., arg = caller_arg(x)) { if (!has_postprocessor_frosting(x)) { - rlang::warn("The workflow has no frosting postprocessor to remove.") + cli_warn("The epi_workflow {arg} has no frosting postprocessor to remove.") return(x) } - x$post$actions[["frosting"]] <- NULL x } @@ -85,11 +116,10 @@ validate_has_postprocessor <- function(x, ..., call = caller_env()) { rlang::check_dots_empty() has_postprocessor <- has_postprocessor_frosting(x) if (!has_postprocessor) { - message <- c( + cli_abort(c( "The workflow must have a frosting postprocessor.", i = "Provide one with `add_frosting()`." - ) - rlang::abort(message, call = call) + ), call = call) } invisible(x) } @@ -97,6 +127,16 @@ validate_has_postprocessor <- function(x, ..., call = caller_env()) { #' @rdname add_frosting #' @export update_frosting <- function(x, frosting, ...) { + UseMethod("update_frosting") +} + +#' @export +update_frosting.default <- function(x, frosting, ..., arg = caller_arg(x)) { + cli_abort("{arg} must be an {.cls epi_workflow}, not a {.cls {class(x)[1]}}.") +} + +#' @export +update_frosting.epi_workflow <- function(x, frosting, ...) { rlang::check_dots_empty() x <- remove_frosting(x) add_frosting(x, frosting) @@ -225,8 +265,8 @@ is_frosting <- function(x) { inherits(x, "frosting") } -#' @importFrom rlang caller_env -validate_frosting <- function(x, ..., arg = "`x`", call = caller_env()) { +#' @importFrom rlang caller_env caller_arg +validate_frosting <- function(x, ..., arg = caller_arg(x), call = caller_env()) { rlang::check_dots_empty() if (!is_frosting(x)) { cli_abort( @@ -237,16 +277,6 @@ validate_frosting <- function(x, ..., arg = "`x`", call = caller_env()) { invisible(x) } -new_frosting <- function() { - structure( - list( - layers = NULL, - requirements = NULL - ), - class = "frosting" - ) -} - #' Create frosting for postprocessing predictions #' @@ -289,11 +319,11 @@ new_frosting <- function() { #' p frosting <- function(layers = NULL, requirements = NULL) { if (!is_null(layers) || !is_null(requirements)) { - cli::cli_abort( + cli_abort( "Currently, no arguments to `frosting()` are allowed to be non-null." ) } - out <- new_frosting() + structure(list(layers = NULL, requirements = NULL), class = "frosting") } diff --git a/R/utils-misc.R b/R/utils-misc.R index af064b37c..7f5c54827 100644 --- a/R/utils-misc.R +++ b/R/utils-misc.R @@ -1,3 +1,5 @@ +`%nin%` <- function(x, table) match(x, table, nomatch = 0) == 0 + #' Check that newly created variable names don't overlap #' #' `check_pname` is to be used in a slather method to ensure that diff --git a/man/add_frosting.Rd b/man/add_frosting.Rd index 6c5b16769..b8a88807e 100644 --- a/man/add_frosting.Rd +++ b/man/add_frosting.Rd @@ -4,16 +4,16 @@ \alias{add_frosting} \alias{remove_frosting} \alias{update_frosting} -\title{Add frosting to a workflow} +\title{Add frosting to an epi_workflow} \usage{ add_frosting(x, frosting, ...) -remove_frosting(x) +remove_frosting(x, ...) update_frosting(x, frosting, ...) } \arguments{ -\item{x}{A workflow} +\item{x}{An epi_workflow} \item{frosting}{A frosting object created using \code{frosting()}.} @@ -23,7 +23,7 @@ update_frosting(x, frosting, ...) \code{x}, updated with a new frosting postprocessor } \description{ -Add frosting to a workflow +Add frosting to an epi_workflow } \examples{ library(dplyr) @@ -56,4 +56,5 @@ wf3 <- wf2 \%>\% remove_frosting() p3 <- predict(wf3, latest) p3 + } diff --git a/tests/testthat/test-frosting.R b/tests/testthat/test-frosting.R index 9c00e210d..69fc399df 100644 --- a/tests/testthat/test-frosting.R +++ b/tests/testthat/test-frosting.R @@ -1,15 +1,15 @@ test_that("frosting validators / constructors work", { wf <- epi_workflow() - expect_s3_class(new_frosting(), "frosting") - expect_true(is_frosting(new_frosting())) - expect_silent(epi_workflow(postprocessor = new_frosting())) + expect_s3_class(frosting(), "frosting") + expect_true(is_frosting(frosting())) + expect_silent(epi_workflow(postprocessor = frosting())) expect_false(has_postprocessor(wf)) expect_false(has_postprocessor_frosting(wf)) - expect_silent(wf %>% add_frosting(new_frosting())) - expect_silent(wf %>% add_postprocessor(new_frosting())) + expect_silent(wf %>% add_frosting(frosting())) + expect_silent(wf %>% add_postprocessor(frosting())) expect_error(wf %>% add_postprocessor(list())) - wf <- wf %>% add_frosting(new_frosting()) + wf <- wf %>% add_frosting(frosting()) expect_true(has_postprocessor(wf)) expect_true(has_postprocessor_frosting(wf)) }) From 63f2efdf2d17c158a5b8042a3a65ef1185029275 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Mon, 23 Sep 2024 08:59:56 -0700 Subject: [PATCH 10/13] remove most workflows::: --- R/extract.R | 2 +- R/frosting.R | 14 -------------- R/workflow-printing.R | 21 ++++++++++----------- 3 files changed, 11 insertions(+), 26 deletions(-) diff --git a/R/extract.R b/R/extract.R index e227b59b1..4a197aa8c 100644 --- a/R/extract.R +++ b/R/extract.R @@ -82,7 +82,7 @@ extract_argument.epi_workflow <- function(x, name, arg, ...) { rlang::check_dots_empty() type <- sub("_.*", "", name) if (type %in% c("check", "step")) { - if (!workflows:::has_preprocessor_recipe(x)) { + if ("recipe" %nin% names(x$pre$actions)) { cli_abort("The workflow must have a recipe preprocessor.") } out <- extract_argument(x$pre$actions$recipe$recipe, name, arg) diff --git a/R/frosting.R b/R/frosting.R index f6f9fb0c5..36ab75c41 100644 --- a/R/frosting.R +++ b/R/frosting.R @@ -69,20 +69,6 @@ add_frosting_postprocessor <- function(wf, action) { wf } -# Hacks around workflows `order_stage_post <- charcter(0)` ---------------- -# epi_add_action <- function(x, action, name, ..., call = caller_env()) { -# workflows:::validate_is_workflow(x, call = call) -# add_action_frosting(x, action, name, ..., call = call) -# } -# add_action_frosting <- function(x, action, name, ..., call = caller_env()) { -# workflows:::check_singleton(x$post$actions, name, call = call) -# x$post <- workflows:::add_action_to_stage(x$post, action, name, order_stage_frosting()) -# x -# } -# order_stage_frosting <- function() "frosting" -# End hacks. See cmu-delphi/epipredict#75 - - #' @rdname add_frosting #' @export remove_frosting <- function(x, ...) { diff --git a/R/workflow-printing.R b/R/workflow-printing.R index d9c3446f9..fa1a818ef 100644 --- a/R/workflow-printing.R +++ b/R/workflow-printing.R @@ -5,17 +5,16 @@ print_header <- function(x) { cli::cli_rule("Epi Workflow{trained}") cli::cli_end(d) - preprocessor_msg <- cli::style_italic("Preprocessor:") preprocessor <- dplyr::case_when( - workflows:::has_preprocessor_formula(x) ~ "Formula", - workflows:::has_preprocessor_recipe(x) ~ "Recipe", - workflows:::has_preprocessor_variables(x) ~ "Variables", + "formula" %in% names(x$pre$actions) ~ "Formula", + "recipe" %in% names(x$pre$actions) ~ "Recipe", + "variables" %in% names(x$pre$actions) ~ "Variables", TRUE ~ "None" ) cli::cli_text("{.emph Preprocessor:} {preprocessor}") - if (workflows:::has_spec(x)) { + if ("model" %in% names(x$fit$actions)) { spec <- class(workflows::extract_spec_parsnip(x))[[1]] spec <- glue::glue("{spec}()") } else { @@ -31,9 +30,9 @@ print_header <- function(x) { print_preprocessor <- function(x) { - has_preprocessor_formula <- workflows:::has_preprocessor_formula(x) - has_preprocessor_recipe <- workflows:::has_preprocessor_recipe(x) - has_preprocessor_variables <- workflows:::has_preprocessor_variables(x) + has_preprocessor_formula <- "formula" %in% names(x$pre$actions) + has_preprocessor_recipe <- "recipe" %in% names(x$pre$actions) + has_preprocessor_variables <- "variables" %in% names(x$pre$actions) no_preprocessor <- !has_preprocessor_formula && !has_preprocessor_recipe && !has_preprocessor_variables @@ -60,12 +59,12 @@ print_preprocessor <- function(x) { # revision of workflows:::print_model() print_model <- function(x) { - has_spec <- workflows:::has_spec(x) + has_spec <- "model" %in% names(x$fit$actions) if (!has_spec) { cli::cli_text("") return(invisible(x)) } - has_fit <- workflows:::has_fit(x) + has_fit <- !is.null(x$fit$fit) cli::cli_rule("Model") if (has_fit) { @@ -128,7 +127,7 @@ print_preprocessor_recipe <- function(x, ...) { return(invisible(x)) } - step_names <- map_chr(steps, workflows:::pull_step_name) + step_names <- map_chr(steps, ~ glue::glue("{class(.x)[[1]]}()")) if (n_steps <= 10L) { cli::cli_ol(step_names) From 2617698bd9446b48d0361ce0019864c4036caece Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Mon, 23 Sep 2024 09:39:18 -0700 Subject: [PATCH 11/13] refactor epi_workflow printing --- R/epi_workflow.R | 6 +- R/workflow-printing.R | 152 ------------------------------------------ 2 files changed, 5 insertions(+), 153 deletions(-) diff --git a/R/epi_workflow.R b/R/epi_workflow.R index 3b018e5e9..762b780db 100644 --- a/R/epi_workflow.R +++ b/R/epi_workflow.R @@ -203,7 +203,11 @@ augment.epi_workflow <- function(x, new_data, ...) { #' @export print.epi_workflow <- function(x, ...) { - NextMethod() + trained <- ifelse(workflows::is_trained_workflow(x), " [trained]", "") + header <- glue::glue("Epi Workflow{trained}") + txt <- utils::capture.output(NextMethod()) + txt[1] <- cli::rule(header, line = 2) + cli::cat_line(txt) print_postprocessor(x) invisible(x) } diff --git a/R/workflow-printing.R b/R/workflow-printing.R index fa1a818ef..04676837c 100644 --- a/R/workflow-printing.R +++ b/R/workflow-printing.R @@ -1,83 +1,3 @@ -print_header <- function(x) { - cli::cli_text("") - trained <- ifelse(workflows::is_trained_workflow(x), " [trained]", "") - d <- cli::cli_div(theme = list(rule = list("line-type" = "double"))) - cli::cli_rule("Epi Workflow{trained}") - cli::cli_end(d) - - preprocessor <- dplyr::case_when( - "formula" %in% names(x$pre$actions) ~ "Formula", - "recipe" %in% names(x$pre$actions) ~ "Recipe", - "variables" %in% names(x$pre$actions) ~ "Variables", - TRUE ~ "None" - ) - cli::cli_text("{.emph Preprocessor:} {preprocessor}") - - - if ("model" %in% names(x$fit$actions)) { - spec <- class(workflows::extract_spec_parsnip(x))[[1]] - spec <- glue::glue("{spec}()") - } else { - spec <- "None" - } - cli::cli_text("{.emph Model:} {spec}") - - postprocessor <- ifelse(has_postprocessor_frosting(x), "Frosting", "None") - cli::cli_text("{.emph Postprocessor:} {postprocessor}") - cli::cli_text("") - invisible(x) -} - - -print_preprocessor <- function(x) { - has_preprocessor_formula <- "formula" %in% names(x$pre$actions) - has_preprocessor_recipe <- "recipe" %in% names(x$pre$actions) - has_preprocessor_variables <- "variables" %in% names(x$pre$actions) - - no_preprocessor <- !has_preprocessor_formula && !has_preprocessor_recipe && - !has_preprocessor_variables - - if (no_preprocessor) { - return(invisible(x)) - } - - cli::cli_rule("Preprocessor") - cli::cli_text("") - - if (has_preprocessor_formula) { - print_preprocessor_formula(x) - } - if (has_preprocessor_recipe) { - print_preprocessor_recipe(x) - } - if (has_preprocessor_variables) { - print_preprocessor_variables(x) - } - cli::cli_text("") - invisible(x) -} - -# revision of workflows:::print_model() -print_model <- function(x) { - has_spec <- "model" %in% names(x$fit$actions) - if (!has_spec) { - cli::cli_text("") - return(invisible(x)) - } - has_fit <- !is.null(x$fit$fit) - cli::cli_rule("Model") - - if (has_fit) { - print_fit(x) - cli::cli_text("") - return(invisible(x)) - } - workflows:::print_spec(x) - cli::cli_text("") - invisible(x) -} - - print_postprocessor <- function(x) { if (!has_postprocessor_frosting(x)) { return(invisible(x)) @@ -93,78 +13,6 @@ print_postprocessor <- function(x) { } -# subfunctions for printing ----------------------------------------------- - - - -print_preprocessor_formula <- function(x) { - formula <- workflows::extract_preprocessor(x) - formula <- rlang::expr_text(formula) - cli::cli_text(formula) - invisible(x) -} - -print_preprocessor_variables <- function(x) { - variables <- workflows::extract_preprocessor(x) - outcomes <- rlang::quo_get_expr(variables$outcomes) - predictors <- rlang::quo_get_expr(variables$predictors) - outcomes <- rlang::expr_text(outcomes) - predictors <- rlang::expr_text(predictors) - cli::cli_text("Outcomes: ", outcomes) - cli::cli_text("") - cli::cli_text("Predictors: ", predictors) - invisible(x) -} - -# Currently only used in the workflow printing -print_preprocessor_recipe <- function(x, ...) { - recipe <- workflows::extract_preprocessor(x) - steps <- recipe$steps - n_steps <- length(steps) - cli::cli_text("{n_steps} Recipe step{?s}.") - - if (n_steps == 0L) { - return(invisible(x)) - } - - step_names <- map_chr(steps, ~ glue::glue("{class(.x)[[1]]}()")) - - if (n_steps <= 10L) { - cli::cli_ol(step_names) - return(invisible(x)) - } - - extra_steps <- n_steps - 10L - step_names <- step_names[1:10] - - cli::cli_ol(step_names) - cli::cli_bullets("... and {extra_steps} more step{?s}.") - invisible(x) -} - - - - -print_fit <- function(x) { - parsnip_fit <- workflows::extract_fit_parsnip(x) - fit <- parsnip_fit$fit - output <- utils::capture.output(fit) - n_output <- length(output) - if (n_output < 50L) { - print(fit) - return(invisible(x)) - } - n_extra_output <- n_output - 50L - output <- output[1:50] - empty_string <- output == "" - output[empty_string] <- " " - - cli::cli_verbatim(output) - cli::cli_text("") - cli::cli_text("... and {n_extra_output} more line{?s}.") - invisible(x) -} - # Currently only used in the workflow printing print_frosting <- function(x, ...) { layers <- x$layers From a945ba7dff7d2ba636fad9ba2d0265a284dffe75 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Mon, 23 Sep 2024 09:39:38 -0700 Subject: [PATCH 12/13] refactor update.layer to remove recipes::: --- R/layers.R | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/R/layers.R b/R/layers.R index ee26e63f9..09ed307a1 100644 --- a/R/layers.R +++ b/R/layers.R @@ -74,15 +74,28 @@ layer <- function(subclass, ..., .prefix = "layer_") { #' p1 #' @export update.layer <- function(object, ...) { - changes <- list(...) + changes <- enlist(...) # Replace the appropriate values in object with the changes - object <- recipes:::update_fields(object, changes) + object <- update_layers(object, changes) # Call layer() to construct a new layer to ensure all new changes are validated reconstruct_layer(object) } +update_layers <- function(object, changes) { + new_nms <- names(changes) + old_nms <- names(object) + layer_type <- class(object)[1] + for (nm in new_nms) { + if (!(nm %in% old_nms)) { + cli::cli_abort("The layer you are trying to update, {.fn {layer_type}}, \\\n does not have the {.field {nm}} field.") + } + object[[nm]] <- changes[[nm]] + } + object +} + reconstruct_layer <- function(x) { # Collect the subclass of the layer to use # when recreating it From 6ad8f27d4865408d33998f195782b19ede392179 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Mon, 23 Sep 2024 10:08:31 -0700 Subject: [PATCH 13/13] remove unusable layer creation utility (subfuns don't exist anymore) --- R/create-layer.R | 45 ------------------------------------------ inst/templates/layer.R | 44 ----------------------------------------- 2 files changed, 89 deletions(-) delete mode 100644 R/create-layer.R delete mode 100644 inst/templates/layer.R diff --git a/R/create-layer.R b/R/create-layer.R deleted file mode 100644 index 0268a906f..000000000 --- a/R/create-layer.R +++ /dev/null @@ -1,45 +0,0 @@ -#' Create a new layer -#' -#' This function creates the skeleton for a new `frosting` layer. When called -#' inside a package, it will create an R script in the `R/` directory, -#' fill in the name of the layer, and open the file. -#' -#' @inheritParams usethis::use_test -#' -#' @importFrom rlang %||% -#' @noRd -#' @keywords internal -#' @examples -#' \dontrun{ -#' -#' # Note: running this will write `layer_strawberry.R` to -#' # the `R/` directory of your current project -#' create_layer("strawberry") -#' } -#' -create_layer <- function(name = NULL, open = rlang::is_interactive()) { - name <- name %||% usethis:::get_active_r_file(path = "R") - if (substr(name, 1, 5) == "layer") { - nn <- substring(name, 6) - if (substr(nn, 1, 1) == "_") nn <- substring(nn, 2) - cli::cli_abort( - c('`name` should not begin with "layer" or "layer_".', - i = 'Did you mean to use `create_layer("{ nn }")`?' - ) - ) - } - layer_name <- name - name <- paste0("layer_", name) - name <- usethis:::slug(name, "R") - usethis:::check_file_name(name) - path <- fs::path("R", name) - if (!fs::file_exists(path)) { - usethis::use_template( - "layer.R", - save_as = path, - data = list(name = layer_name), open = FALSE, - package = "epipredict" - ) - } - usethis::edit_file(usethis::proj_path(path), open = open) -} diff --git a/inst/templates/layer.R b/inst/templates/layer.R deleted file mode 100644 index 59556db5f..000000000 --- a/inst/templates/layer.R +++ /dev/null @@ -1,44 +0,0 @@ -layer_{{ name }} <- function(frosting, # mandatory - ..., - args, # add as many as you need - more_args, - id = rand_id("{{{ name }}}")) { - - # validate any additional arguments here - - # if you don't need ... then uncomment the line below - ## rlang::check_dots_empty() - add_layer( - frosting, - layer_{{{ name }}}_new( - terms = dplyr::enquos(...), # remove if ... should be empty - args, - id = id - ) - ) -} - -layer_{{{ name }}}_new <- function(terms, args, more_args, id) { - layer("{{{ name }}}", - terms = terms, - args = args, - more_args = more_args, - id = id) -} - -#' @export -slather.layer_{{{ name }}} <- - function(object, components, workflow, new_data, ...) { - rlang::check_dots_empty() - - # if layer_ used ... in tidyselect, we need to evaluate it now - exprs <- rlang::expr(c(!!!object$terms)) - pos <- tidyselect::eval_select(exprs, components$predictions) - col_names <- names(pos) - # now can select with `tidyselect::all_of(col_names)` - - # add additional necessary processing steps here - - # always return components - components - }