diff --git a/DESCRIPTION b/DESCRIPTION index 02f0252..8df9225 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -4,7 +4,7 @@ Version: 1.0.2.9000 Authors@R: c( person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre"), comment = c(ORCID = "0000-0003-2402-136X")), - person(given = "Posit Software, PBC", role = c("cph", "fnd")) + person("Posit Software, PBC", role = c("cph", "fnd")) ) Description: Tree- and rule-based models can be bagged () using this package and their predictions @@ -28,7 +28,7 @@ Imports: hardhat (>= 1.1.0), magrittr, purrr, - rlang, + rlang (>= 1.1.0), rpart, rsample, tibble, @@ -36,9 +36,8 @@ Imports: utils, withr Suggests: - AmesHousing, covr, - earth, + earth, modeldata, nnet, recipes, @@ -48,7 +47,31 @@ Suggests: yardstick Config/Needs/website: tidyverse/tidytemplate Config/testthat/edition: 3 +Config/usethis/last-upkeep: 2024-10-23 Encoding: UTF-8 Language: en-US Roxygen: list(markdown = TRUE) RoxygenNote: 7.3.2 +Collate: + 'C5.0.R' + 'bag_mars_data.R' + 'bag_nnet_data.R' + 'bag_tree_data.R' + 'import-standalone-types-check.R' + 'validate.R' + 'bagger.R' + 'baguette-package.R' + 'bridge.R' + 'cart.R' + 'class_cost.R' + 'constructor.R' + 'cost_models.R' + 'import-standalone-obj-type.R' + 'mars.R' + 'misc.R' + 'model_info.R' + 'nnet.R' + 'out-of-bag.R' + 'predict.R' + 'var_imp.R' + 'zzz.R' diff --git a/R/C5.0.R b/R/C5.0.R index 7c7b651..9ded28e 100644 --- a/R/C5.0.R +++ b/R/C5.0.R @@ -32,7 +32,7 @@ c5_bagger <- function(rs, control, ...) { if (control$reduce) { rs <- rs %>% - mutate(model = map(model, axe_C5)) + dplyr::mutate(model = purrr::map(model, axe_C5)) } list(model = rs, imp = imps) diff --git a/R/aaa_validate.R b/R/aaa_validate.R deleted file mode 100644 index bb54465..0000000 --- a/R/aaa_validate.R +++ /dev/null @@ -1,160 +0,0 @@ -validate_args <- function(model, times, control, cost) { - if (!is.character(model) || length(model) != 1) { - cli::cli_abort("`base_model` should be a single character value.") - } - if (!(model %in% baguette_models)) { - msg <- paste( - "`base_model` should be one of ", - paste0("'", baguette_models, "'", collapse = ", ") - ) - cli::cli_abort(msg) - } - - # ---------------------------------------------------------------------------- - - if (!is.null(cost) & !(model %in% c("CART", "C5.0"))) { - cli::cli_abort("`base_model` should be either 'CART' or 'C5.0'") - } - if (!is.null(cost)) { - if (is.numeric(cost) && any(cost < 0)) { - cli::cli_abort("`cost` should be non-negative.") - } - } - - # ---------------------------------------------------------------------------- - - if (!is.integer(times)) { - cli::cli_abort("`times` must be an integer > 1.") - } - if (times < 1) { - cli::cli_abort("`times` must be an integer > 1.") - } - - # ---------------------------------------------------------------------------- - - validate_control(control) - - # ---------------------------------------------------------------------------- - - invisible(TRUE) - -} - -integer_B <- function(B) { - if (is.numeric(B) & !is.integer(B)) { - B <- as.integer(B) - } - B -} - -# ------------------------------------------------------------------------------ - -validate_y_type <- function(base_model, outcomes) { - hardhat::validate_outcomes_are_univariate(outcomes) - - if (base_model == "C5.0") { - hardhat::validate_outcomes_are_factors(outcomes) - } - -} - -# ------------------------------------------------------------------------------ - -model_failure <- function(x) { - if (inherits(x, "model_fit")) { - res <- inherits(x$fit, "try-error") - } else { - res <- inherits(x, "try-error") - } - res -} - -check_for_disaster <- function(x) { - x <- dplyr::mutate(x, passed = !purrr::map_lgl(model, model_failure)) - - if (sum(x$passed) == 0) { - if (inherits(x$model[[1]], "try-error")) { - msg <- as.character(x$model[[1]]) - } else { - if (inherits(x$model[[1]], "model_fit")) { - msg <- as.character(x$model[[1]]$fit) - } else msg <- NA - } - - if (!is.na(msg)) { - msg <- paste0("An example message was:\n ", msg) - } else msg <- "" - - - cli::cli_abort(paste0("All of the models failed. ", msg)) - } - x -} - -# ------------------------------------------------------------------------------ - -check_type <- function(object, type) { - if (is.null(type)) { - if (object$base_model[2] == "classification") { - type <- "class" - } else { - type <- "numeric" - } - } else { - if (object$base_model[2] == "classification") { - if (!(type %in% c("class", "prob"))) - cli::cli_abort("`type` should be either 'class' or 'prob'") - } else { - if (type != "numeric") - cli::cli_abort("`type` should be 'numeric'") - } - } - type -} - -validate_importance <- function(x) { - if (is.null(x)) { - return(x) - } - - if (!is_tibble(x)) { - cli::cli_abort("Imprtance score results should be a tibble.") - } - - exp_cols <- c("term", "value", "std.error", "used") - if (!isTRUE(all.equal(exp_cols, names(x)))) { - msg <- paste0("Importance columns should be: ", - paste0("'", exp_cols, "'", collapse = ", "), - "." - ) - cli::cli_abort(msg) - } - x -} - -# ------------------------------------------------------------------------------ - -validate_control <- function(x) { - if (!is.list(x)) { - cli::cli_abort("The control object should be a list created by `control_bag()`.") - } - samps <- c("none", "down") - - if (length(x$var_imp) != 1 || !is.logical(x$var_imp)) { - cli::cli_abort("`var_imp` should be a single logical value.") - } - if (length(x$allow_parallel) != 1 || !is.logical(x$allow_parallel)) { - cli::cli_abort("`allow_parallel` should be a single logical value.") - } - if (length(x$sampling) != 1 || !is.character(x$sampling) || !any(samps == x$sampling)) { - cli::cli_abort("`sampling` should be either 'none' or 'down'.") - } - if (length(x$reduce) != 1 || !is.logical(x$reduce)) { - cli::cli_abort("`reduce` should be a single logical value.") - } - if (!is.null(x$extract) && !is.function(x$extract)) { - cli::cli_abort("`extract` should be NULL or a function.") - } - - x -} diff --git a/R/bag_mars_data.R b/R/bag_mars_data.R index 37aa68f..f1f450a 100644 --- a/R/bag_mars_data.R +++ b/R/bag_mars_data.R @@ -3,7 +3,7 @@ # they are already in the parsnip model database. We'll exclude them from # coverage stats for this reason. -# nocov +# nocov start make_bag_mars <- function() { diff --git a/R/bag_nnet_data.R b/R/bag_nnet_data.R index 5605136..2f21f7e 100644 --- a/R/bag_nnet_data.R +++ b/R/bag_nnet_data.R @@ -3,7 +3,7 @@ # they are already in the parsnip model database. We'll exclude them from # coverage stats for this reason. -# nocov +# nocov start make_bag_mlp <- function() { diff --git a/R/bag_tree_data.R b/R/bag_tree_data.R index 1e476b6..a140667 100644 --- a/R/bag_tree_data.R +++ b/R/bag_tree_data.R @@ -3,7 +3,7 @@ # they are already in the parsnip model database. We'll exclude them from # coverage stats for this reason. -# nocov +# nocov start make_bag_tree <- function() { diff --git a/R/bagger.R b/R/bagger.R index 797c223..aaac3b0 100644 --- a/R/bagger.R +++ b/R/bagger.R @@ -19,7 +19,7 @@ #' @param times A single integer greater than 1 for the maximum number of bootstrap #' samples/ensemble members (some model fits might fail). #' @param control A list of options generated by `control_bag()`. -#' @param cost A non-negative scale (for two class problems) or a cost matrix. +#' @param cost A non-negative scale (for two class problems) or a square cost matrix. #' @param ... Optional arguments to pass to the base model function. #' @details `bagger()` fits separate models to bootstrap samples. The #' prediction function for each model object is encoded in an R expression and @@ -89,6 +89,7 @@ #' cart_pca_bag #' } #' @export +#' @include validate.R bagger <- function(x, ...) { UseMethod("bagger") } diff --git a/R/aaa.R b/R/baguette-package.R similarity index 69% rename from R/aaa.R rename to R/baguette-package.R index dc2d4d4..c7776e4 100644 --- a/R/aaa.R +++ b/R/baguette-package.R @@ -1,7 +1,10 @@ +#' @keywords internal +"_PACKAGE" + +## usethis namespace: start #' @import rlang #' @import dplyr #' @import hardhat -#' #' @importFrom parsnip set_engine fit fit_xy control_parsnip mars decision_tree #' @importFrom parsnip set_new_model multi_predict update_dot_check show_fit #' @importFrom parsnip new_model_spec null_value update_main_parameters @@ -17,8 +20,20 @@ #' @importFrom withr with_seed #' @importFrom dials new_quant_param #' @importFrom stats coef -#' -# ------------------------------------------------------------------------------ + +#' @keywords internal +"_PACKAGE" + +#' @importFrom magrittr %>% +#' @export +magrittr::`%>%` + +#' @importFrom generics var_imp +#' @export +generics::var_imp + +## usethis namespace: end +NULL utils::globalVariables( c( @@ -42,16 +57,3 @@ utils::globalVariables( ".estimator" ) ) - -# ------------------------------------------------------------------------------ - -# The functions below define the model information. These access the model -# environment inside of parsnip so they have to be executed once parsnip has -# been loaded. - -.onLoad <- function(libname, pkgname) { - # This defines model functions in the parsnip model database - make_bag_tree() - make_bag_mars() - make_bag_mlp() -} diff --git a/R/bridge.R b/R/bridge.R index 0178b9d..9283336 100644 --- a/R/bridge.R +++ b/R/bridge.R @@ -1,4 +1,5 @@ -bagger_bridge <- function(processed, weights, base_model, seed, times, control, cost, ...) { +bagger_bridge <- function(processed, weights, base_model, seed, times, control, + cost, ..., call = rlang::caller_env()) { validate_outcomes_are_univariate(processed$outcomes) if (base_model %in% c("C5.0")) { validate_outcomes_are_factors(processed$outcomes) @@ -27,8 +28,8 @@ bagger_bridge <- function(processed, weights, base_model, seed, times, control, } else { res <- switch( base_model, - CART = cost_sens_cart_bagger(rs, control, cost, ...), - C5.0 = cost_sens_c5_bagger(rs, control, cost, ...) + CART = cost_sens_cart_bagger(rs, control, cost, ..., call = call), + C5.0 = cost_sens_c5_bagger(rs, control, cost, ..., call = call) ) } @@ -43,16 +44,3 @@ bagger_bridge <- function(processed, weights, base_model, seed, times, control, ) res } - -validate_case_weights <- function(weights, data) { - if (is.null(weights)) { - return(invisible(NULL)) - } - n <- nrow(data) - if (!is.vector(weights) || !is.numeric(weights) || length(weights) != n || - any(weights < 0)) { - cli::cli_abort("'weights' should be a non-negative numeric vector with the same size as the data.") - } - invisible(NULL) -} - diff --git a/R/cart.R b/R/cart.R index c9aa131..58a0e4a 100644 --- a/R/cart.R +++ b/R/cart.R @@ -34,7 +34,7 @@ cart_bagger <- function(rs, control, ...) { if (control$reduce) { rs <- rs %>% - mutate(model = map(model, axe_cart)) + dplyr::mutate(model = purrr::map(model, axe_cart)) } list(model = rs, imp = imps) diff --git a/R/constructor.R b/R/constructor.R index 1856203..9c2fc48 100644 --- a/R/constructor.R +++ b/R/constructor.R @@ -1,9 +1,11 @@ -new_bagger <- function(model_df, imp, control, cost, base_model, blueprint) { +new_bagger <- function(model_df, imp, control, cost, base_model, blueprint, + call = rlang::caller_env()) { if (!is_tibble(model_df)) { - cli::cli_abort("`model_df` should be a tibble.") + cli::cli_abort("{.arg model_df} should be {.cls tibble}.", call = call) } + # TODO extend to use mode from model object(s) if (is.numeric(blueprint$ptypes$outcomes[[1]])) { mod_mode <- "regression" } else { diff --git a/R/cost_models.R b/R/cost_models.R index 436d545..eabb19c 100644 --- a/R/cost_models.R +++ b/R/cost_models.R @@ -1,9 +1,10 @@ -cost_matrix <- function(x, lvl, truth_is_row = TRUE) { +cost_matrix <- function(x, lvl, truth_is_row = TRUE, call = rlang::caller_env()) { if (is.matrix(x)) { } else { if (length(lvl) != 2) { - cli::cli_abort("`cost` can only be a scalar when there are two levels.") + cli::cli_abort("{.arg cost} can only be a scalar when there are two + levels.", call = call) } else { x0 <- x x <- matrix(1, ncol = 2, nrow = 2) @@ -19,14 +20,14 @@ cost_matrix <- function(x, lvl, truth_is_row = TRUE) { x } -cost_sens_cart_bagger <- function(rs, control, cost, ...) { +cost_sens_cart_bagger <- function(rs, control, cost, ..., call = rlang::caller_env()) { # capture dots opt <- rlang::dots_list(...) nms <- names(opt) lvl <- levels(rs$splits[[1]]$data$.outcome) - cost <- cost_matrix(cost, lvl) + cost <- cost_matrix(cost, lvl, call = call) # Attach cost matrix to parms = list(loss) but first # check existing options passed by user for loss @@ -41,14 +42,14 @@ cost_sens_cart_bagger <- function(rs, control, cost, ...) { -cost_sens_c5_bagger <- function(rs, control, cost, ...) { +cost_sens_c5_bagger <- function(rs, control, cost, ..., call = rlang::caller_env()) { # capture dots opt <- rlang::dots_list(...) nms <- names(opt) lvl <- levels(rs$splits[[1]]$data$.outcome) - cost <- cost_matrix(cost, lvl, truth_is_row = FALSE) + cost <- cost_matrix(cost, lvl, truth_is_row = FALSE, call = call) # Attach cost matrix to options opt$costs <- cost diff --git a/R/import-standalone-obj-type.R b/R/import-standalone-obj-type.R new file mode 100644 index 0000000..646aa33 --- /dev/null +++ b/R/import-standalone-obj-type.R @@ -0,0 +1,363 @@ +# Standalone file: do not edit by hand +# Source: +# ---------------------------------------------------------------------- +# +# --- +# repo: r-lib/rlang +# file: standalone-obj-type.R +# last-updated: 2024-02-14 +# license: https://unlicense.org +# imports: rlang (>= 1.1.0) +# --- +# +# ## Changelog +# +# 2024-02-14: +# - `obj_type_friendly()` now works for S7 objects. +# +# 2023-05-01: +# - `obj_type_friendly()` now only displays the first class of S3 objects. +# +# 2023-03-30: +# - `stop_input_type()` now handles `I()` input literally in `arg`. +# +# 2022-10-04: +# - `obj_type_friendly(value = TRUE)` now shows numeric scalars +# literally. +# - `stop_friendly_type()` now takes `show_value`, passed to +# `obj_type_friendly()` as the `value` argument. +# +# 2022-10-03: +# - Added `allow_na` and `allow_null` arguments. +# - `NULL` is now backticked. +# - Better friendly type for infinities and `NaN`. +# +# 2022-09-16: +# - Unprefixed usage of rlang functions with `rlang::` to +# avoid onLoad issues when called from rlang (#1482). +# +# 2022-08-11: +# - Prefixed usage of rlang functions with `rlang::`. +# +# 2022-06-22: +# - `friendly_type_of()` is now `obj_type_friendly()`. +# - Added `obj_type_oo()`. +# +# 2021-12-20: +# - Added support for scalar values and empty vectors. +# - Added `stop_input_type()` +# +# 2021-06-30: +# - Added support for missing arguments. +# +# 2021-04-19: +# - Added support for matrices and arrays (#141). +# - Added documentation. +# - Added changelog. +# +# nocov start + +#' Return English-friendly type +#' @param x Any R object. +#' @param value Whether to describe the value of `x`. Special values +#' like `NA` or `""` are always described. +#' @param length Whether to mention the length of vectors and lists. +#' @return A string describing the type. Starts with an indefinite +#' article, e.g. "an integer vector". +#' @noRd +obj_type_friendly <- function(x, value = TRUE) { + if (is_missing(x)) { + return("absent") + } + + if (is.object(x)) { + if (inherits(x, "quosure")) { + type <- "quosure" + } else { + type <- class(x)[[1L]] + } + return(sprintf("a <%s> object", type)) + } + + if (!is_vector(x)) { + return(.rlang_as_friendly_type(typeof(x))) + } + + n_dim <- length(dim(x)) + + if (!n_dim) { + if (!is_list(x) && length(x) == 1) { + if (is_na(x)) { + return(switch( + typeof(x), + logical = "`NA`", + integer = "an integer `NA`", + double = + if (is.nan(x)) { + "`NaN`" + } else { + "a numeric `NA`" + }, + complex = "a complex `NA`", + character = "a character `NA`", + .rlang_stop_unexpected_typeof(x) + )) + } + + show_infinites <- function(x) { + if (x > 0) { + "`Inf`" + } else { + "`-Inf`" + } + } + str_encode <- function(x, width = 30, ...) { + if (nchar(x) > width) { + x <- substr(x, 1, width - 3) + x <- paste0(x, "...") + } + encodeString(x, ...) + } + + if (value) { + if (is.numeric(x) && is.infinite(x)) { + return(show_infinites(x)) + } + + if (is.numeric(x) || is.complex(x)) { + number <- as.character(round(x, 2)) + what <- if (is.complex(x)) "the complex number" else "the number" + return(paste(what, number)) + } + + return(switch( + typeof(x), + logical = if (x) "`TRUE`" else "`FALSE`", + character = { + what <- if (nzchar(x)) "the string" else "the empty string" + paste(what, str_encode(x, quote = "\"")) + }, + raw = paste("the raw value", as.character(x)), + .rlang_stop_unexpected_typeof(x) + )) + } + + return(switch( + typeof(x), + logical = "a logical value", + integer = "an integer", + double = if (is.infinite(x)) show_infinites(x) else "a number", + complex = "a complex number", + character = if (nzchar(x)) "a string" else "\"\"", + raw = "a raw value", + .rlang_stop_unexpected_typeof(x) + )) + } + + if (length(x) == 0) { + return(switch( + typeof(x), + logical = "an empty logical vector", + integer = "an empty integer vector", + double = "an empty numeric vector", + complex = "an empty complex vector", + character = "an empty character vector", + raw = "an empty raw vector", + list = "an empty list", + .rlang_stop_unexpected_typeof(x) + )) + } + } + + vec_type_friendly(x) +} + +vec_type_friendly <- function(x, length = FALSE) { + if (!is_vector(x)) { + abort("`x` must be a vector.") + } + type <- typeof(x) + n_dim <- length(dim(x)) + + add_length <- function(type) { + if (length && !n_dim) { + paste0(type, sprintf(" of length %s", length(x))) + } else { + type + } + } + + if (type == "list") { + if (n_dim < 2) { + return(add_length("a list")) + } else if (is.data.frame(x)) { + return("a data frame") + } else if (n_dim == 2) { + return("a list matrix") + } else { + return("a list array") + } + } + + type <- switch( + type, + logical = "a logical %s", + integer = "an integer %s", + numeric = , + double = "a double %s", + complex = "a complex %s", + character = "a character %s", + raw = "a raw %s", + type = paste0("a ", type, " %s") + ) + + if (n_dim < 2) { + kind <- "vector" + } else if (n_dim == 2) { + kind <- "matrix" + } else { + kind <- "array" + } + out <- sprintf(type, kind) + + if (n_dim >= 2) { + out + } else { + add_length(out) + } +} + +.rlang_as_friendly_type <- function(type) { + switch( + type, + + list = "a list", + + NULL = "`NULL`", + environment = "an environment", + externalptr = "a pointer", + weakref = "a weak reference", + S4 = "an S4 object", + + name = , + symbol = "a symbol", + language = "a call", + pairlist = "a pairlist node", + expression = "an expression vector", + + char = "an internal string", + promise = "an internal promise", + ... = "an internal dots object", + any = "an internal `any` object", + bytecode = "an internal bytecode object", + + primitive = , + builtin = , + special = "a primitive function", + closure = "a function", + + type + ) +} + +.rlang_stop_unexpected_typeof <- function(x, call = caller_env()) { + abort( + sprintf("Unexpected type <%s>.", typeof(x)), + call = call + ) +} + +#' Return OO type +#' @param x Any R object. +#' @return One of `"bare"` (for non-OO objects), `"S3"`, `"S4"`, +#' `"R6"`, or `"S7"`. +#' @noRd +obj_type_oo <- function(x) { + if (!is.object(x)) { + return("bare") + } + + class <- inherits(x, c("R6", "S7_object"), which = TRUE) + + if (class[[1]]) { + "R6" + } else if (class[[2]]) { + "S7" + } else if (isS4(x)) { + "S4" + } else { + "S3" + } +} + +#' @param x The object type which does not conform to `what`. Its +#' `obj_type_friendly()` is taken and mentioned in the error message. +#' @param what The friendly expected type as a string. Can be a +#' character vector of expected types, in which case the error +#' message mentions all of them in an "or" enumeration. +#' @param show_value Passed to `value` argument of `obj_type_friendly()`. +#' @param ... Arguments passed to [abort()]. +#' @inheritParams args_error_context +#' @noRd +stop_input_type <- function(x, + what, + ..., + allow_na = FALSE, + allow_null = FALSE, + show_value = TRUE, + arg = caller_arg(x), + call = caller_env()) { + # From standalone-cli.R + cli <- env_get_list( + nms = c("format_arg", "format_code"), + last = topenv(), + default = function(x) sprintf("`%s`", x), + inherit = TRUE + ) + + if (allow_na) { + what <- c(what, cli$format_code("NA")) + } + if (allow_null) { + what <- c(what, cli$format_code("NULL")) + } + if (length(what)) { + what <- oxford_comma(what) + } + if (inherits(arg, "AsIs")) { + format_arg <- identity + } else { + format_arg <- cli$format_arg + } + + message <- sprintf( + "%s must be %s, not %s.", + format_arg(arg), + what, + obj_type_friendly(x, value = show_value) + ) + + abort(message, ..., call = call, arg = arg) +} + +oxford_comma <- function(chr, sep = ", ", final = "or") { + n <- length(chr) + + if (n < 2) { + return(chr) + } + + head <- chr[seq_len(n - 1)] + last <- chr[n] + + head <- paste(head, collapse = sep) + + # Write a or b. But a, b, or c. + if (n > 2) { + paste0(head, sep, final, " ", last) + } else { + paste0(head, " ", final, " ", last) + } +} + +# nocov end diff --git a/R/import-standalone-types-check.R b/R/import-standalone-types-check.R new file mode 100644 index 0000000..1ca8399 --- /dev/null +++ b/R/import-standalone-types-check.R @@ -0,0 +1,553 @@ +# Standalone file: do not edit by hand +# Source: +# ---------------------------------------------------------------------- +# +# --- +# repo: r-lib/rlang +# file: standalone-types-check.R +# last-updated: 2023-03-13 +# license: https://unlicense.org +# dependencies: standalone-obj-type.R +# imports: rlang (>= 1.1.0) +# --- +# +# ## Changelog +# +# 2024-08-15: +# - `check_character()` gains an `allow_na` argument (@martaalcalde, #1724) +# +# 2023-03-13: +# - Improved error messages of number checkers (@teunbrand) +# - Added `allow_infinite` argument to `check_number_whole()` (@mgirlich). +# - Added `check_data_frame()` (@mgirlich). +# +# 2023-03-07: +# - Added dependency on rlang (>= 1.1.0). +# +# 2023-02-15: +# - Added `check_logical()`. +# +# - `check_bool()`, `check_number_whole()`, and +# `check_number_decimal()` are now implemented in C. +# +# - For efficiency, `check_number_whole()` and +# `check_number_decimal()` now take a `NULL` default for `min` and +# `max`. This makes it possible to bypass unnecessary type-checking +# and comparisons in the default case of no bounds checks. +# +# 2022-10-07: +# - `check_number_whole()` and `_decimal()` no longer treat +# non-numeric types such as factors or dates as numbers. Numeric +# types are detected with `is.numeric()`. +# +# 2022-10-04: +# - Added `check_name()` that forbids the empty string. +# `check_string()` allows the empty string by default. +# +# 2022-09-28: +# - Removed `what` arguments. +# - Added `allow_na` and `allow_null` arguments. +# - Added `allow_decimal` and `allow_infinite` arguments. +# - Improved errors with absent arguments. +# +# +# 2022-09-16: +# - Unprefixed usage of rlang functions with `rlang::` to +# avoid onLoad issues when called from rlang (#1482). +# +# 2022-08-11: +# - Added changelog. +# +# nocov start + +# Scalars ----------------------------------------------------------------- + +.standalone_types_check_dot_call <- .Call + +check_bool <- function(x, + ..., + allow_na = FALSE, + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env()) { + if (!missing(x) && .standalone_types_check_dot_call(ffi_standalone_is_bool_1.0.7, x, allow_na, allow_null)) { + return(invisible(NULL)) + } + + stop_input_type( + x, + c("`TRUE`", "`FALSE`"), + ..., + allow_na = allow_na, + allow_null = allow_null, + arg = arg, + call = call + ) +} + +check_string <- function(x, + ..., + allow_empty = TRUE, + allow_na = FALSE, + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env()) { + if (!missing(x)) { + is_string <- .rlang_check_is_string( + x, + allow_empty = allow_empty, + allow_na = allow_na, + allow_null = allow_null + ) + if (is_string) { + return(invisible(NULL)) + } + } + + stop_input_type( + x, + "a single string", + ..., + allow_na = allow_na, + allow_null = allow_null, + arg = arg, + call = call + ) +} + +.rlang_check_is_string <- function(x, + allow_empty, + allow_na, + allow_null) { + if (is_string(x)) { + if (allow_empty || !is_string(x, "")) { + return(TRUE) + } + } + + if (allow_null && is_null(x)) { + return(TRUE) + } + + if (allow_na && (identical(x, NA) || identical(x, na_chr))) { + return(TRUE) + } + + FALSE +} + +check_name <- function(x, + ..., + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env()) { + if (!missing(x)) { + is_string <- .rlang_check_is_string( + x, + allow_empty = FALSE, + allow_na = FALSE, + allow_null = allow_null + ) + if (is_string) { + return(invisible(NULL)) + } + } + + stop_input_type( + x, + "a valid name", + ..., + allow_na = FALSE, + allow_null = allow_null, + arg = arg, + call = call + ) +} + +IS_NUMBER_true <- 0 +IS_NUMBER_false <- 1 +IS_NUMBER_oob <- 2 + +check_number_decimal <- function(x, + ..., + min = NULL, + max = NULL, + allow_infinite = TRUE, + allow_na = FALSE, + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env()) { + if (missing(x)) { + exit_code <- IS_NUMBER_false + } else if (0 == (exit_code <- .standalone_types_check_dot_call( + ffi_standalone_check_number_1.0.7, + x, + allow_decimal = TRUE, + min, + max, + allow_infinite, + allow_na, + allow_null + ))) { + return(invisible(NULL)) + } + + .stop_not_number( + x, + ..., + exit_code = exit_code, + allow_decimal = TRUE, + min = min, + max = max, + allow_na = allow_na, + allow_null = allow_null, + arg = arg, + call = call + ) +} + +check_number_whole <- function(x, + ..., + min = NULL, + max = NULL, + allow_infinite = FALSE, + allow_na = FALSE, + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env()) { + if (missing(x)) { + exit_code <- IS_NUMBER_false + } else if (0 == (exit_code <- .standalone_types_check_dot_call( + ffi_standalone_check_number_1.0.7, + x, + allow_decimal = FALSE, + min, + max, + allow_infinite, + allow_na, + allow_null + ))) { + return(invisible(NULL)) + } + + .stop_not_number( + x, + ..., + exit_code = exit_code, + allow_decimal = FALSE, + min = min, + max = max, + allow_na = allow_na, + allow_null = allow_null, + arg = arg, + call = call + ) +} + +.stop_not_number <- function(x, + ..., + exit_code, + allow_decimal, + min, + max, + allow_na, + allow_null, + arg, + call) { + if (allow_decimal) { + what <- "a number" + } else { + what <- "a whole number" + } + + if (exit_code == IS_NUMBER_oob) { + min <- min %||% -Inf + max <- max %||% Inf + + if (min > -Inf && max < Inf) { + what <- sprintf("%s between %s and %s", what, min, max) + } else if (x < min) { + what <- sprintf("%s larger than or equal to %s", what, min) + } else if (x > max) { + what <- sprintf("%s smaller than or equal to %s", what, max) + } else { + abort("Unexpected state in OOB check", .internal = TRUE) + } + } + + stop_input_type( + x, + what, + ..., + allow_na = allow_na, + allow_null = allow_null, + arg = arg, + call = call + ) +} + +check_symbol <- function(x, + ..., + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env()) { + if (!missing(x)) { + if (is_symbol(x)) { + return(invisible(NULL)) + } + if (allow_null && is_null(x)) { + return(invisible(NULL)) + } + } + + stop_input_type( + x, + "a symbol", + ..., + allow_na = FALSE, + allow_null = allow_null, + arg = arg, + call = call + ) +} + +check_arg <- function(x, + ..., + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env()) { + if (!missing(x)) { + if (is_symbol(x)) { + return(invisible(NULL)) + } + if (allow_null && is_null(x)) { + return(invisible(NULL)) + } + } + + stop_input_type( + x, + "an argument name", + ..., + allow_na = FALSE, + allow_null = allow_null, + arg = arg, + call = call + ) +} + +check_call <- function(x, + ..., + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env()) { + if (!missing(x)) { + if (is_call(x)) { + return(invisible(NULL)) + } + if (allow_null && is_null(x)) { + return(invisible(NULL)) + } + } + + stop_input_type( + x, + "a defused call", + ..., + allow_na = FALSE, + allow_null = allow_null, + arg = arg, + call = call + ) +} + +check_environment <- function(x, + ..., + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env()) { + if (!missing(x)) { + if (is_environment(x)) { + return(invisible(NULL)) + } + if (allow_null && is_null(x)) { + return(invisible(NULL)) + } + } + + stop_input_type( + x, + "an environment", + ..., + allow_na = FALSE, + allow_null = allow_null, + arg = arg, + call = call + ) +} + +check_function <- function(x, + ..., + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env()) { + if (!missing(x)) { + if (is_function(x)) { + return(invisible(NULL)) + } + if (allow_null && is_null(x)) { + return(invisible(NULL)) + } + } + + stop_input_type( + x, + "a function", + ..., + allow_na = FALSE, + allow_null = allow_null, + arg = arg, + call = call + ) +} + +check_closure <- function(x, + ..., + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env()) { + if (!missing(x)) { + if (is_closure(x)) { + return(invisible(NULL)) + } + if (allow_null && is_null(x)) { + return(invisible(NULL)) + } + } + + stop_input_type( + x, + "an R function", + ..., + allow_na = FALSE, + allow_null = allow_null, + arg = arg, + call = call + ) +} + +check_formula <- function(x, + ..., + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env()) { + if (!missing(x)) { + if (is_formula(x)) { + return(invisible(NULL)) + } + if (allow_null && is_null(x)) { + return(invisible(NULL)) + } + } + + stop_input_type( + x, + "a formula", + ..., + allow_na = FALSE, + allow_null = allow_null, + arg = arg, + call = call + ) +} + + +# Vectors ----------------------------------------------------------------- + +# TODO: Figure out what to do with logical `NA` and `allow_na = TRUE` + +check_character <- function(x, + ..., + allow_na = TRUE, + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env()) { + + if (!missing(x)) { + if (is_character(x)) { + if (!allow_na && any(is.na(x))) { + abort( + sprintf("`%s` can't contain NA values.", arg), + arg = arg, + call = call + ) + } + + return(invisible(NULL)) + } + + if (allow_null && is_null(x)) { + return(invisible(NULL)) + } + } + + stop_input_type( + x, + "a character vector", + ..., + allow_null = allow_null, + arg = arg, + call = call + ) +} + +check_logical <- function(x, + ..., + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env()) { + if (!missing(x)) { + if (is_logical(x)) { + return(invisible(NULL)) + } + if (allow_null && is_null(x)) { + return(invisible(NULL)) + } + } + + stop_input_type( + x, + "a logical vector", + ..., + allow_na = FALSE, + allow_null = allow_null, + arg = arg, + call = call + ) +} + +check_data_frame <- function(x, + ..., + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env()) { + if (!missing(x)) { + if (is.data.frame(x)) { + return(invisible(NULL)) + } + if (allow_null && is_null(x)) { + return(invisible(NULL)) + } + } + + stop_input_type( + x, + "a data frame", + ..., + allow_null = allow_null, + arg = arg, + call = call + ) +} + +# nocov end diff --git a/R/mars.R b/R/mars.R index ff871ea..27afc73 100644 --- a/R/mars.R +++ b/R/mars.R @@ -33,7 +33,7 @@ mars_bagger <- function(rs, control, ...) { if (control$reduce) { rs <- rs %>% - mutate(model = map(model, axe_mars)) + dplyr::mutate(model = purrr::map(model, axe_mars)) } list(model = rs, imp = imps) diff --git a/R/misc.R b/R/misc.R index 6d3c66a..276a6c2 100644 --- a/R/misc.R +++ b/R/misc.R @@ -27,10 +27,10 @@ compute_imp <- function(rs, .fn, compute) { used = length(predictor) ) %>% dplyr::select(-sds) %>% - dplyr::arrange(desc(value)) %>% + dplyr::arrange(dplyr::desc(value)) %>% dplyr::rename(term = predictor) } else { - imps <- tibble( + imps <- tibble::tibble( term = character(0), value = numeric(0), std.error = numeric(0), @@ -48,7 +48,7 @@ compute_imp <- function(rs, .fn, compute) { extractor <- function(rs, extract) { if (!is.null(extract)) { - rs <- rs %>% dplyr::mutate(extras = map(model, ~ extract(.x$fit))) + rs <- rs %>% dplyr::mutate(extras = purrr::map(model, ~ extract(.x$fit))) } rs } @@ -84,9 +84,9 @@ down_sampler <- function(x) { min_n <- min(table(x$.outcome)) x %>% - group_by(.outcome) %>% - sample_n(size = min_n, replace = TRUE) %>% - ungroup() + dplyr::group_by(.outcome) %>% + dplyr::sample_n(size = min_n, replace = TRUE) %>% + dplyr::ungroup() } @@ -112,7 +112,7 @@ replaced <- function(x, replacement) { replace_parsnip_terms <- function(x) { new_terms <- butcher::axe_env(x$model[[1]]$preproc$terms) x <- x %>% - mutate(model = map(model, replaced, replacement = new_terms)) + dplyr::mutate(model = purrr::map(model, replaced, replacement = new_terms)) x } @@ -121,12 +121,9 @@ replace_parsnip_terms <- function(x) { # fix column names (see https://github.com/tidymodels/parsnip/issues/263) fix_column_names <- function(result, object) { - # print("# ------------------------------------------------------------------------------\n") - # print(head(result)) nms <- colnames(result) nms <- gsub(".pred_", "", nms, fixed = TRUE) result <- setNames(result, nms) - # print(head(result)) result } diff --git a/R/nnet.R b/R/nnet.R index 423eaa2..c1531ac 100644 --- a/R/nnet.R +++ b/R/nnet.R @@ -34,7 +34,7 @@ nnet_bagger <- function(rs, control, ...) { if (control$reduce) { rs <- rs %>% - mutate(model = map(model, axe_nnet)) + dplyr::mutate(model = purrr::map(model, axe_nnet)) } list(model = rs, imp = imps) diff --git a/R/out-of-bag.R b/R/out-of-bag.R index f9d11b7..119de15 100644 --- a/R/out-of-bag.R +++ b/R/out-of-bag.R @@ -48,7 +48,7 @@ compute_oob <- function(rs, oob) { purrr::map2_dfr(rs$model, rs$splits, .fn, met = oob) %>% dplyr::group_by(.metric, .estimator) %>% dplyr::summarize(.estimate = mean(.estimate, na.rm = TRUE)) %>% - mutate(.estimator = "out-of-bag") %>% + dplyr::mutate(.estimator = "out-of-bag") %>% dplyr::ungroup() } else { oob <- NULL diff --git a/R/reexports.R b/R/reexports.R deleted file mode 100644 index f4d6abe..0000000 --- a/R/reexports.R +++ /dev/null @@ -1,10 +0,0 @@ -#' @keywords internal -"_PACKAGE" - -#' @importFrom magrittr %>% -#' @export -magrittr::`%>%` - -#' @importFrom generics var_imp -#' @export -generics::var_imp diff --git a/R/validate.R b/R/validate.R new file mode 100644 index 0000000..abd3b6b --- /dev/null +++ b/R/validate.R @@ -0,0 +1,155 @@ +#' @include import-standalone-types-check.R +validate_args <- function(model, times, control, cost, call = rlang::caller_env()) { + model <- rlang::arg_match(model, baguette_models, error_arg = "base_model", + error_call = call) + + if (!is.null(cost) & !(model %in% c("CART", "C5.0"))) { + cli::cli_abort("When using misclassification costs, {.arg base_model} should + be either {.val CART} or {.val C5.0}.", call = call) + } + if (!is.matrix(cost)) { + check_number_decimal(cost, allow_null = TRUE, min = 0, call = call) + } else { + is_sq <- nrow(cost) == ncol(cost) + if (!is.numeric(cost) || !is_sq) { + cli::cli_abort("If {.arg cost} is a matrix, is must be numeric and square.", + call = call) + } + } + + check_number_whole(times, min = 2, call = call) + + validate_control(control, call = call) + + invisible(TRUE) + +} + +integer_B <- function(B) { + if (is.numeric(B) & !is.integer(B)) { + B <- as.integer(B) + } + B +} + +# ------------------------------------------------------------------------------ + +validate_y_type <- function(base_model, outcomes) { + hardhat::validate_outcomes_are_univariate(outcomes) + + if (base_model == "C5.0") { + hardhat::validate_outcomes_are_factors(outcomes) + } + +} + +# ------------------------------------------------------------------------------ + +model_failure <- function(x) { + if (inherits(x, "model_fit")) { + res <- inherits(x$fit, "try-error") + } else { + res <- inherits(x, "try-error") + } + res +} + +check_for_disaster <- function(x, call = rlang::caller_env()) { + x <- dplyr::mutate(x, passed = !purrr::map_lgl(model, model_failure)) + + if (sum(x$passed) == 0) { + if (inherits(x$model[[1]], "try-error")) { + msg <- as.character(x$model[[1]]) + } else { + if (inherits(x$model[[1]], "model_fit")) { + msg <- as.character(x$model[[1]]$fit) + } else msg <- NA + } + + if (!is.na(msg)) { + # escape any brackets in the error message + msg <- gsub("(\\{)", "\\1\\1", msg) + msg <- gsub("(\\})", "\\1\\1", msg) + msg <- cli::format_error(msg) + cli::cli_abort(c("All of the models failed. Example:", "x" = "{msg}")) + } else { + cli::cli_abort("All of the models failed.") + } + } + x +} + +# ------------------------------------------------------------------------------ + +check_type <- function(object, type, call = rlang::caller_env()) { + model_type <- object$base_model[2] + model_modes <- parsnip::get_from_env("modes") + if (is.null(type)) { + if (model_type == "classification") { + type <- "class" + } else if (model_type == "regression") { + type <- "numeric" + } + } else { + if (model_type == "classification") { + type <- rlang::arg_match(type, c("class", "prob"), error_call = call) + } else if (model_type == "regression") { + type <- rlang::arg_match(type, c("numeric"), error_call = call) + } else { + cli::cli_abort("Model mode {.val {model_type}} is not allowed + Possible values are {.or {.val {model_modes}}}.", + call = call) + } + } + type +} + +validate_importance <- function(x, call = rlang::caller_env()) { + if (is.null(x)) { + return(x) + } + + if (!is_tibble(x)) { + cli::cli_abort("Imprtance score results should be a tibble.", call = call) + } + + exp_cols <- c("term", "value", "std.error", "used") + if (!isTRUE(all.equal(exp_cols, names(x)))) { + cli::cli_abort("Importance columns should be: {.val {exp_cols}}.", call = call) + } + x +} + +# ------------------------------------------------------------------------------ + +validate_control <- function(x, call = rlang::caller_env()) { + if (!is.list(x)) { + cli::cli_abort("The control object should be a list created by + {.fn control_bag}.", call = call) + } + + check_bool(x$var_imp, arg = "var_imp", call = call) + check_bool(x$allow_parallel, arg = "allow_parallel", call = call) + x$sampling <- rlang::arg_match0(x$sampling, c("none", "down"), + arg_nm = "sampling", error_call = call) + check_bool(x$reduce, arg = "reduce", call = call) + check_function(x$extract, allow_null = TRUE, arg = "extract", call = call) + + x +} + +# ------------------------------------------------------------------------------ + +validate_case_weights <- function(weights, data, call = rlang::caller_env()) { + if (is.null(weights)) { + return(invisible(NULL)) + } + n <- nrow(data) + if (!is.vector(weights) || !is.numeric(weights) || length(weights) != n || + any(weights < 0)) { + cli::cli_abort("{.arg weights} should be a non-negative numeric vector + with the same size as the data.", call = call) + } + invisible(NULL) +} + diff --git a/R/zzz.R b/R/zzz.R new file mode 100644 index 0000000..0433e53 --- /dev/null +++ b/R/zzz.R @@ -0,0 +1,14 @@ +# nocov start + +# The functions below define the model information. These access the model +# environment inside of parsnip so they have to be executed once parsnip has +# been loaded. + +.onLoad <- function(libname, pkgname) { + # This defines model functions in the parsnip model database + make_bag_tree() + make_bag_mars() + make_bag_mlp() +} + +# nocov end diff --git a/man/bagger.Rd b/man/bagger.Rd index 224c31a..8364909 100644 --- a/man/bagger.Rd +++ b/man/bagger.Rd @@ -75,7 +75,7 @@ samples/ensemble members (some model fits might fail).} \item{control}{A list of options generated by \code{control_bag()}.} -\item{cost}{A non-negative scale (for two class problems) or a cost matrix.} +\item{cost}{A non-negative scale (for two class problems) or a square cost matrix.} \item{formula}{An object of class "formula" (or one that can be coerced to that class): a symbolic description of the model to be fitted. Note that diff --git a/man/baguette-package.Rd b/man/baguette-package.Rd index ca00654..a755281 100644 --- a/man/baguette-package.Rd +++ b/man/baguette-package.Rd @@ -1,13 +1,16 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/reexports.R +% Please edit documentation in R/baguette-package.R \docType{package} \name{baguette-package} -\alias{baguette} \alias{baguette-package} \title{baguette: Efficient Model Functions for Bagging} \description{ \if{html}{\figure{logo.png}{options: style='float: right' alt='logo' width='120'}} +Tree- and rule-based models can be bagged (\doi{10.1007/BF00058655}) using this package and their predictions equations are stored in an efficient format to reduce the model objects size and speed. + +\if{html}{\figure{logo.png}{options: style='float: right' alt='logo' width='120'}} + Tree- and rule-based models can be bagged (\doi{10.1007/BF00058655}) using this package and their predictions equations are stored in an efficient format to reduce the model objects size and speed. } \seealso{ @@ -18,6 +21,14 @@ Useful links: \item Report bugs at \url{https://github.com/tidymodels/baguette/issues} } + +Useful links: +\itemize{ + \item \url{https://baguette.tidymodels.org} + \item \url{https://github.com/tidymodels/baguette} + \item Report bugs at \url{https://github.com/tidymodels/baguette/issues} +} + } \author{ \strong{Maintainer}: Max Kuhn \email{max@posit.co} (\href{https://orcid.org/0000-0003-2402-136X}{ORCID}) diff --git a/man/reexports.Rd b/man/reexports.Rd index 991e371..4e64029 100644 --- a/man/reexports.Rd +++ b/man/reexports.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/reexports.R +% Please edit documentation in R/baguette-package.R \docType{import} \name{reexports} \alias{reexports} diff --git a/tests/testthat/_snaps/validation.md b/tests/testthat/_snaps/validation.md index 78278db..87b8374 100644 --- a/tests/testthat/_snaps/validation.md +++ b/tests/testthat/_snaps/validation.md @@ -1,4 +1,230 @@ +# bad values + + Code + baguette:::validate_args(model = "mars", times = 5L, control = control_bag(), + cost = NULL) + Condition + Error: + ! `base_model` must be one of "CART", "C5.0", "MARS", or "nnet", not "mars". + i Did you mean "MARS"? + +--- + + Code + baguette:::validate_args(model = "MARS", times = 1, control = control_bag(), + cost = NULL) + Condition + Error: + ! `times` must be a whole number larger than or equal to 2, not the number 1. + +--- + + Code + baguette:::validate_args(model = "MARS", times = -1L, control = control_bag(), + cost = NULL) + Condition + Error: + ! `times` must be a whole number larger than or equal to 2, not the number -1. + +--- + + Code + baguette:::validate_args(model = "MARS", times = 5L, control = 2, cost = NULL) + Condition + Error: + ! The control object should be a list created by `control_bag()`. + +--- + + Code + bagger(Sepal.Length ~ ., data = iris, times = 2L, base_model = "CART", cost = 2) + Condition + Error in `bagger()`: + ! `cost` can only be a scalar when there are two levels. + +# wrong y for C5 + + Code + bagger(Sepal.Length ~ ., data = iris, times = 2L, base_model = "C5.0") + Condition + Error in `validate_outcomes_are_factors()`: + ! All outcomes must be factors, but the following are not: + 'Sepal.Length': 'numeric' + +# validate imps + + Code + baguette:::validate_importance(tibble::tibble(terms = letters[1:2], value = 1:2, + std.error = 1:2)) + Condition + Error: + ! Importance columns should be: "term", "value", "std.error", and "used". + +--- + + Code + baguette:::validate_importance(data.frame(term = letters[1:2], value = 1:2, + std.error = 1:2)) + Condition + Error: + ! Imprtance score results should be a tibble. + # bad inputs - `type` should be either 'class' or 'prob' + Code + bagger(mpg ~ ., data = mtcars, base_model = letters[1:2]) + Condition + Error in `bagger()`: + ! `base_model` must be one of "CART", "C5.0", "MARS", or "nnet", not "a". + +--- + + Code + bagger(mpg ~ ., data = mtcars, base_model = "MARS", cost = 2) + Condition + Error in `bagger()`: + ! When using misclassification costs, `base_model` should be either "CART" or "C5.0". + +--- + + Code + bagger(mpg ~ ., data = mtcars, base_model = "CART", cost = -2) + Condition + Error in `bagger()`: + ! `cost` must be a number larger than or equal to 0 or `NULL`, not the number -2. + +--- + + Code + bagger(mpg ~ ., data = mtcars, base_model = "CART", cost = matrix(1, ncol = 2, + nrow = 1)) + Condition + Error in `bagger()`: + ! If `cost` is a matrix, is must be numeric and square. + +--- + + Code + bagger(mpg ~ ., data = mtcars, base_model = "MARS", control = control_bag( + extract = 2)) + Condition + Error in `control_bag()`: + ! `extract` must be a function or `NULL`, not the number 2. + +--- + + Code + bagger(mpg ~ ., data = mtcars, base_model = "C5.0") + Condition + Error in `validate_outcomes_are_factors()`: + ! All outcomes must be factors, but the following are not: + 'mpg': 'numeric' + +--- + + Code + bagger(wt + mpg ~ ., data = mtcars, base_model = "MARS") + Condition + Error in `validate_outcomes_are_univariate()`: + ! The outcome must be univariate, but 2 columns were found. + +--- + + Code + predict(bagger(mpg ~ ., data = mtcars, base_model = "MARS"), mtcars[1:2, -1], + type = "potato") + Condition + Error in `predict()`: + ! `type` must be one of "numeric", not "potato". + +--- + + Code + set.seed(3983) + predict(bagger(Class ~ ., data = two_class_dat, base_model = "MARS"), + two_class_dat[1:2, -3], type = "topepo") + Condition + Warning: + There were 2 warnings in `dplyr::mutate()`. + The first warning was: + i In argument: `model = iter(...)`. + Caused by warning: + ! glm.fit: fitted probabilities numerically 0 or 1 occurred + i Run `dplyr::last_dplyr_warnings()` to see the 1 remaining warning. + Error in `predict()`: + ! `type` must be one of "class" or "prob", not "topepo". + +# model failures inputs + + Code + set.seed(459394) + bagger(a ~ ., data = bad_iris, base_model = "CART", times = 3) + Condition + Error in `check_for_disaster()`: + ! All of the models failed. Example: + x Error in cbind(yval2, yprob, nodeprob) : number of rows of matrices must match (see arg 2) + +# control inputs + + Code + control_bag(reduce = 1:2) + Condition + Error in `control_bag()`: + ! `reduce` must be `TRUE` or `FALSE`, not an integer vector. + +--- + + Code + control_bag(reduce = 1) + Condition + Error in `control_bag()`: + ! `reduce` must be `TRUE` or `FALSE`, not the number 1. + +--- + + Code + control_bag(sampling = rep("none", 2)) + Condition + Error in `control_bag()`: + ! `arg` must be length 1 or a permutation of `c("none", "down")`. + +--- + + Code + control_bag(sampling = 1) + Condition + Error in `control_bag()`: + ! `sampling` must be a string or character vector. + +--- + + Code + control_bag(allow_parallel = 1:2) + Condition + Error in `control_bag()`: + ! `allow_parallel` must be `TRUE` or `FALSE`, not an integer vector. + +--- + + Code + control_bag(allow_parallel = 1) + Condition + Error in `control_bag()`: + ! `allow_parallel` must be `TRUE` or `FALSE`, not the number 1. + +--- + + Code + control_bag(var_imp = 1:2) + Condition + Error in `control_bag()`: + ! `var_imp` must be `TRUE` or `FALSE`, not an integer vector. + +--- + + Code + control_bag(var_imp = 1) + Condition + Error in `control_bag()`: + ! `var_imp` must be `TRUE` or `FALSE`, not the number 1. diff --git a/tests/testthat/test-validation.R b/tests/testthat/test-validation.R index a871251..1b940bf 100644 --- a/tests/testthat/test-validation.R +++ b/tests/testthat/test-validation.R @@ -1,86 +1,96 @@ test_that('good values', { - expect_error( + expect_no_error( baguette:::validate_args( model = "MARS", times = 5L, control = control_bag(), cost = NULL - ), - regexp = NA + ) ) }) test_that('bad values', { - expect_error( + expect_snapshot( baguette:::validate_args( model = "mars", times = 5L, control = control_bag(), - cost = NULL - ), - regexp = "`base_model`", - class = "rlang_error" + cost = NULL), + error = TRUE ) - expect_error( + + expect_snapshot( baguette:::validate_args( model = "MARS", times = 1, control = control_bag(), cost = NULL ), - regexp = "integer" + error = TRUE ) - expect_error( + + expect_snapshot( baguette:::validate_args( model = "MARS", times = -1L, control = control_bag(), cost = NULL ), - regexp = "integer" + error = TRUE ) - expect_error( + + expect_snapshot( baguette:::validate_args( model = "MARS", times = 5L, control = 2, cost = NULL ), - regexp = "should be a list" + error = TRUE ) - expect_error( - bagger(Sepal.Length ~ ., data = iris, times = 2L, base_model = "CART", cost = 2), - regexp = "`cost` can only be a scalar" + + expect_snapshot( + bagger( + Sepal.Length ~ ., + data = iris, + times = 2L, + base_model = "CART", + cost = 2 + ), + error = TRUE ) + }) # ------------------------------------------------------------------------------ test_that('wrong y for C5', { - expect_error( + expect_snapshot( bagger(Sepal.Length ~ ., data = iris, times = 2L, base_model = "C5.0"), - regexp = "must be factors" + error = TRUE ) }) # ------------------------------------------------------------------------------ test_that('validate imps', { - expect_error( + + expect_snapshot( baguette:::validate_importance( tibble::tibble( terms = letters[1:2], value = 1:2, std.error = 1:2 ) - ) + ), + error = TRUE ) - expect_error( + + expect_snapshot( baguette:::validate_importance( - data.frame(term = letters[1:2], - value = 1:2, - std.error = 1:2) - ) + data.frame(term = letters[1:2], value = 1:2, std.error = 1:2) + ), + error = TRUE ) }) @@ -90,49 +100,71 @@ test_that('bad inputs', { skip_if_not_installed("earth") skip_if_not_installed("modeldata") - expect_error( + expect_snapshot( bagger(mpg ~ ., data = mtcars, base_model = letters[1:2]), - "should be a single character value." + error = TRUE ) - expect_error( + + expect_snapshot( bagger(mpg ~ ., data = mtcars, base_model = "MARS", cost = 2), - "should be either 'CART' or 'C5.0'" + error = TRUE ) - expect_error( + + expect_snapshot( bagger(mpg ~ ., data = mtcars, base_model = "CART", cost = -2), - "`cost` should be non-negative" + error = TRUE ) - expect_error( - bagger(mpg ~ ., data = mtcars, base_model = "CART", cost = matrix(-2, ncol = 2, nrow = 2)), - "`cost` should be non-negative" + + expect_snapshot( + bagger( + mpg ~ ., + data = mtcars, + base_model = "CART", + cost = matrix(1, ncol = 2, nrow = 1) + ), + error = TRUE ) - expect_error( - bagger(mpg ~ ., data = mtcars, base_model = "MARS", control = control_bag(extract = 2)), - "`extract` should be NULL or a function" + + expect_snapshot( + bagger( + mpg ~ ., + data = mtcars, + base_model = "MARS", + control = control_bag(extract = 2) + ), + error = TRUE ) - expect_error( + + expect_snapshot( bagger(mpg ~ ., data = mtcars, base_model = "C5.0"), - "All outcomes must be factors, but the following are not" + error = TRUE ) - expect_error( + + expect_snapshot( bagger(wt + mpg ~ ., data = mtcars, base_model = "MARS"), - "The outcome must be univariate" + error = TRUE ) - expect_error( + + expect_snapshot( predict(bagger(mpg ~ ., data = mtcars, base_model = "MARS"), mtcars[1:2, -1], type = "potato"), - "`type` should be 'numeric'" + error = TRUE ) + if (compareVersion(as.character(getRversion()), "3.6.0") > 0) { expect_warning(RNGkind(sample.kind = "Rounding")) } - set.seed(3983) - expect_snapshot_error( + + expect_snapshot({ + set.seed(3983) predict(bagger(Class ~ ., data = two_class_dat, base_model = "MARS"), two_class_dat[1:2, -3], - type = "topepo"), - ) + type = "topepo") + }, + error = TRUE + ) + }) # ------------------------------------------------------------------------------ @@ -144,48 +176,57 @@ test_that('model failures inputs', { if (compareVersion(as.character(getRversion()), "3.6.0") > 0) { expect_warning(RNGkind(sample.kind = "Rounding")) } - set.seed(459394) - expect_error( - bagger(a ~ ., data = bad_iris, base_model = "CART", times = 3), - "All of the models failed" - ) + expect_snapshot({ + set.seed(459394) + bagger(a ~ ., data = bad_iris, base_model = "CART", times = 3) + }, + error = TRUE + ) }) - # ------------------------------------------------------------------------------ test_that('control inputs', { - expect_error( - control_bag(var_imp = 1), - "`var_imp` should be a single logical value." - ) - expect_error( - control_bag(var_imp = 1:2), - "`var_imp` should be a single logical value." + + expect_snapshot( + control_bag(reduce = 1:2), + error = TRUE ) - expect_error( - control_bag(allow_parallel = 1), - "`allow_parallel` should be a single logical value." + + expect_snapshot( + control_bag(reduce = 1), + error = TRUE ) - expect_error( - control_bag(allow_parallel = 1:2), - "`allow_parallel` should be a single logical value." + + expect_snapshot( + control_bag(sampling = rep("none", 2)), + error = TRUE ) - expect_error( + + expect_snapshot( control_bag(sampling = 1), - "`sampling` should be either 'none' or 'down'" + error = TRUE ) - expect_error( - control_bag(sampling = rep("none", 2)), - "`sampling` should be either 'none' or 'down'" + + expect_snapshot( + control_bag(allow_parallel = 1:2), + error = TRUE ) - expect_error( - control_bag(reduce = 1), - "`reduce` should be a single logical value." + + expect_snapshot( + control_bag(allow_parallel = 1), + error = TRUE ) - expect_error( - control_bag(reduce = 1:2), - "`reduce` should be a single logical value." + + expect_snapshot( + control_bag(var_imp = 1:2), + error = TRUE ) + + expect_snapshot( + control_bag(var_imp = 1), + error = TRUE + ) + })