Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify workflow #399

Draft
wants to merge 17 commits into
base: v0.2.0
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ S3method(Remove_model,epi_workflow)
S3method(Remove_model,workflow)
S3method(Update_model,epi_workflow)
S3method(Update_model,workflow)
S3method(add_frosting,default)
S3method(add_frosting,epi_workflow)
S3method(adjust_epi_recipe,epi_recipe)
S3method(adjust_epi_recipe,epi_workflow)
S3method(adjust_frosting,epi_workflow)
Expand Down Expand Up @@ -97,6 +99,8 @@ S3method(quantile,dist_quantiles)
S3method(recipe,epi_df)
S3method(recipes::recipe,formula)
S3method(refresh_blueprint,default_epi_recipe_blueprint)
S3method(remove_frosting,default)
S3method(remove_frosting,epi_workflow)
S3method(residuals,flatline)
S3method(run_mold,default_epi_recipe_blueprint)
S3method(slather,layer_add_forecast_date)
Expand All @@ -119,6 +123,8 @@ S3method(tidy,check_enough_train_data)
S3method(tidy,frosting)
S3method(tidy,layer)
S3method(update,layer)
S3method(update_frosting,default)
S3method(update_frosting,epi_workflow)
S3method(vec_ptype_abbr,dist_quantiles)
S3method(vec_ptype_full,dist_quantiles)
S3method(weighted_interval_score,default)
Expand Down Expand Up @@ -271,6 +277,7 @@ importFrom(rlang,":=")
importFrom(rlang,abort)
importFrom(rlang,arg_match)
importFrom(rlang,as_function)
importFrom(rlang,caller_arg)
importFrom(rlang,caller_env)
importFrom(rlang,enquo)
importFrom(rlang,enquos)
Expand Down
45 changes: 0 additions & 45 deletions R/create-layer.R

This file was deleted.

151 changes: 29 additions & 122 deletions R/epi_recipe.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,23 +95,12 @@ add_epi_recipe <- function(
#' @rdname add_epi_recipe
#' @export
remove_epi_recipe <- function(x) {
workflows:::validate_is_workflow(x)

if (!workflows:::has_preprocessor_recipe(x)) {
rlang::warn("The workflow has no recipe preprocessor to remove.")
}

actions <- x$pre$actions
actions[["recipe"]] <- NULL

new_epi_workflow(
pre = workflows:::new_stage_pre(actions = actions),
fit = x$fit,
post = x$post,
trained = FALSE
)
x <- workflows::remove_recipe(x)
class(x) <- c("epi_workflow", class(x))
x
}


#' @rdname add_epi_recipe
#' @export
update_epi_recipe <- function(x, recipe, ..., blueprint = default_epi_recipe_blueprint()) {
Expand Down Expand Up @@ -180,15 +169,21 @@ adjust_epi_recipe <- function(x, which_step, ..., blueprint = default_epi_recipe

#' @rdname adjust_epi_recipe
#' @export
adjust_epi_recipe.epi_workflow <- function(x, which_step, ..., blueprint = default_epi_recipe_blueprint()) {
recipe <- adjust_epi_recipe(workflows::extract_preprocessor(x), which_step, ...)
adjust_epi_recipe.epi_workflow <- function(
x, which_step, ..., blueprint = default_epi_recipe_blueprint()
) {

update_epi_recipe(x, recipe, blueprint = blueprint)
rec <- adjust_epi_recipe(
workflows::extract_preprocessor(x), which_step, ...
)
update_epi_recipe(x, rec, blueprint = blueprint)
}

#' @rdname adjust_epi_recipe
#' @export
adjust_epi_recipe.epi_recipe <- function(x, which_step, ..., blueprint = default_epi_recipe_blueprint()) {
adjust_epi_recipe.epi_recipe <- function(
x, which_step, ..., blueprint = default_epi_recipe_blueprint()
) {
if (!(is.numeric(which_step) || is.character(which_step))) {
cli::cli_abort(
c("`which_step` must be a number or a character.",
Expand Down Expand Up @@ -294,109 +289,21 @@ kill_levels <- function(x, keys) {

#' @export
print.epi_recipe <- function(x, form_width = 30, ...) {
cli::cli_div(theme = list(.pkg = list("vec-trunc" = Inf, "vec-last" = ", ")))

cli::cli_h1("Epi Recipe")
cli::cli_h3("Inputs")

tab <- table(x$var_info$role, useNA = "ifany")
tab <- stats::setNames(tab, names(tab))
names(tab)[is.na(names(tab))] <- "undeclared role"

roles <- c("outcome", "predictor", "case_weights", "undeclared role")

tab <- c(
tab[names(tab) == roles[1]],
tab[names(tab) == roles[2]],
tab[names(tab) == roles[3]],
sort(tab[!names(tab) %in% roles], TRUE),
tab[names(tab) == roles[4]]
)

cli::cli_text("Number of variables by role")

spaces_needed <- max(nchar(names(tab))) - nchar(names(tab)) +
max(nchar(tab)) - nchar(tab)

cli::cli_verbatim(
glue::glue("{names(tab)}: {strrep('\ua0', spaces_needed)}{tab}")
)

if ("tr_info" %in% names(x)) {
cli::cli_h3("Training information")
nmiss <- x$tr_info$nrows - x$tr_info$ncomplete
nrows <- x$tr_info$nrows

cli::cli_text(
"Training data contained {nrows} data points and {cli::no(nmiss)} \\
incomplete row{?s}."
)
}

if (!is.null(x$steps)) {
cli::cli_h3("Operations")
}

fmt <- cli::cli_fmt({
for (step in x$steps) {
print(step, form_width = form_width)
}
})
cli::cli_ol(fmt)
cli::cli_end()

invisible(x)
}

# Currently only used in the workflow printing
print_preprocessor_recipe <- function(x, ...) {
recipe <- workflows::extract_preprocessor(x)
steps <- recipe$steps
n_steps <- length(steps)
cli::cli_text("{n_steps} Recipe step{?s}.")

if (n_steps == 0L) {
return(invisible(x))
}

step_names <- map_chr(steps, workflows:::pull_step_name)

if (n_steps <= 10L) {
cli::cli_ol(step_names)
return(invisible(x))
}

extra_steps <- n_steps - 10L
step_names <- step_names[1:10]

cli::cli_ol(step_names)
cli::cli_bullets("... and {extra_steps} more step{?s}.")
invisible(x)
}

print_preprocessor <- function(x) {
has_preprocessor_formula <- workflows:::has_preprocessor_formula(x)
has_preprocessor_recipe <- workflows:::has_preprocessor_recipe(x)
has_preprocessor_variables <- workflows:::has_preprocessor_variables(x)

no_preprocessor <- !has_preprocessor_formula && !has_preprocessor_recipe &&
!has_preprocessor_variables

if (no_preprocessor) {
return(invisible(x))
}

cli::cli_rule("Preprocessor")
cli::cli_text("")

if (has_preprocessor_formula) {
workflows:::print_preprocessor_formula(x)
}
if (has_preprocessor_recipe) {
print_preprocessor_recipe(x)
}
if (has_preprocessor_variables) {
workflows:::print_preprocessor_variables(x)
o <- cli::cli_fmt(NextMethod())
# Fix up the recipe name
rr <- unlist(strsplit(o[2], "Recipe"))
len <- nchar(rr[2])
h1_tail <- paste0(substr(rr[2], 1, len / 2 - 10), substr(rr[2], len / 2, len))
o[2] <- paste0(rr[1], "Epi Recipe", h1_tail)

# Number the operations
has_operations <- any(grepl(" Operations ", o, fixed = TRUE))
if (has_operations) {
ops <- seq(grep(" Operations ", o, fixed = TRUE) + 1, length(o))
# kills the \bullet
rep_ops <- sub("^\\033\\[36m.\\033\\[39m ", "", o[ops], perl = TRUE)
o[ops] <- paste0(ops - ops[1] + 1, ". ", rep_ops)
}
cli::cli_bullets(o)
invisible(x)
}
38 changes: 13 additions & 25 deletions R/epi_workflow.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,15 @@
#'
#' wf
epi_workflow <- function(preprocessor = NULL, spec = NULL, postprocessor = NULL) {
out <- workflows::workflow(spec = spec)
class(out) <- c("epi_workflow", class(out))

out <- workflows::workflow(preprocessor, spec = spec)
if (is_epi_recipe(preprocessor)) {
out <- workflows::remove_recipe(out)
out <- add_epi_recipe(out, preprocessor)
} else if (!is_null(preprocessor)) {
out <- workflows:::add_preprocessor(out, preprocessor)
}
class(out) <- c("epi_workflow", class(out))
if (!is_null(postprocessor)) {
out <- add_postprocessor(out, postprocessor)
}

out
}

Expand Down Expand Up @@ -101,7 +98,6 @@ fit.epi_workflow <- function(object, data, ..., control = workflows::control_wor
as_of = attributes(data)$metadata$as_of
)
object$original_data <- data

NextMethod()
}

Expand Down Expand Up @@ -162,11 +158,14 @@ predict.epi_workflow <- function(object, new_data, type = NULL, opts = list(), .
}
components <- list()
components$mold <- workflows::extract_mold(object)
components$forged <- hardhat::forge(new_data,
components$forged <- hardhat::forge(
new_data,
blueprint = components$mold$blueprint
)
components$keys <- grab_forged_keys(components$forged, object, new_data)
components <- apply_frosting(object, components, new_data, type = type, opts = opts, ...)
components <- apply_frosting(
object, components, new_data, type = type, opts = opts, ...
)
components$predictions
}

Expand Down Expand Up @@ -201,25 +200,14 @@ augment.epi_workflow <- function(x, new_data, ...) {
full_join(predictions, new_data, by = join_by)
}

new_epi_workflow <- function(
pre = workflows:::new_stage_pre(),
fit = workflows:::new_stage_fit(),
post = workflows:::new_stage_post(),
trained = FALSE) {
out <- workflows:::new_workflow(
pre = pre, fit = fit, post = post, trained = trained
)
class(out) <- c("epi_workflow", class(out))
out
}


#' @export
print.epi_workflow <- function(x, ...) {
print_header(x)
print_preprocessor(x)
# workflows:::print_case_weights(x)
print_model(x)
trained <- ifelse(workflows::is_trained_workflow(x), " [trained]", "")
header <- glue::glue("Epi Workflow{trained}")
txt <- utils::capture.output(NextMethod())
txt[1] <- cli::rule(header, line = 2)
cli::cat_line(txt)
print_postprocessor(x)
invisible(x)
}
Expand Down
2 changes: 1 addition & 1 deletion R/extract.R
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ extract_argument.epi_workflow <- function(x, name, arg, ...) {
rlang::check_dots_empty()
type <- sub("_.*", "", name)
if (type %in% c("check", "step")) {
if (!workflows:::has_preprocessor_recipe(x)) {
if ("recipe" %nin% names(x$pre$actions)) {
cli_abort("The workflow must have a recipe preprocessor.")
}
out <- extract_argument(x$pre$actions$recipe$recipe, name, arg)
Expand Down
Loading
Loading