From a82446b1b548f5253be4d57e973be0f70c05edd4 Mon Sep 17 00:00:00 2001 From: topepo Date: Tue, 16 Jul 2019 19:51:13 -0400 Subject: [PATCH] changes for #3 --- NAMESPACE | 2 +- R/bagger.R | 3 ++- R/bridge.R | 3 +-- R/cart.R | 27 +++++++++++---------------- R/mars.R | 25 ++++++++++++++++--------- R/misc.R | 21 ++++++++++++++++++--- R/validate.R | 8 ++------ man/bagger.Rd | 4 +++- 8 files changed, 54 insertions(+), 39 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index e39b819..d8bebcc 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -22,7 +22,6 @@ importFrom(earth,earth) importFrom(earth,evimp) importFrom(furrr,future_map) importFrom(furrr,future_map2) -importFrom(hardhat,validate_outcomes_is_univariate) importFrom(magrittr,"%>%") importFrom(parsnip,decision_tree) importFrom(parsnip,fit) @@ -47,5 +46,6 @@ importFrom(stats,setNames) importFrom(tibble,as_tibble) importFrom(tibble,is_tibble) importFrom(tibble,tibble) +importFrom(tidypredict,tidypredict_fit) importFrom(utils,globalVariables) importFrom(withr,with_seed) diff --git a/R/bagger.R b/R/bagger.R index 13755d8..0ee9f17 100644 --- a/R/bagger.R +++ b/R/bagger.R @@ -21,6 +21,7 @@ #' model function. A list of possible arguments per model are given in Details. #' @param extract A function (or NULL) that can extract model-related aspects #' of each ensemble member. See Details and example below. +#' @param control A list of options generated by `bag_control()`. #' @param ... Optional arguments to pass to the `extract` function. #' @details `bagger()` fits separate models to bootstrap samples. The #' prediction function for each model object is encoded in an R expression and @@ -84,7 +85,7 @@ bagger.data.frame <- y, model = "CART", B = 10L, - opt = NULLL, + opt = NULL, control = bag_control(), extract = NULL, ...) { diff --git a/R/bridge.R b/R/bridge.R index c281acf..0b77dd0 100644 --- a/R/bridge.R +++ b/R/bridge.R @@ -1,5 +1,4 @@ -#' @importFrom hardhat validate_outcomes_is_univariate -#' @importFrom rsample bootstraps + bagger_bridge <- function(processed, model, seed, B, opt, control, extract, ...) { validate_outcomes_is_univariate(processed$outcomes) diff --git a/R/cart.R b/R/cart.R index 3dbe147..6a7b3e7 100644 --- a/R/cart.R +++ b/R/cart.R @@ -1,10 +1,3 @@ -#' @importFrom rpart rpart -#' @importFrom rsample analysis -#' @importFrom purrr map map2 map_df -#' @importFrom tibble tibble -#' @importFrom parsnip decision_tree -#' @importFrom furrr future_map -#' @importFrom partykit as.party.rpart cart_bagger <- function(rs, opt, control, extract, ...) { is_classif <- is.factor(rs$splits[[1]]$data$.outcome) @@ -14,7 +7,14 @@ cart_bagger <- function(rs, opt, control, extract, ...) { rs <- rs %>% - dplyr::mutate(model = iter(fit_seed, splits, seed_fit, .fn = cart_fit, spec = mod_spec)) + dplyr::mutate(model = iter( + fit_seed, + splits, + seed_fit, + .fn = cart_fit, + spec = mod_spec, + control = control + )) rs <- check_for_disaster(rs) @@ -86,22 +86,17 @@ make_cart_spec <- function(classif, opt) { cart_spec } -#' @importFrom stats complete.cases -cart_fit <- function(split, spec, sampling = "none") { +cart_fit <- function(split, spec, control = bag_control()) { dat <- rsample::analysis(split) - if (sampling == "down") { + if (control$sampling == "down") { dat <- down_sampler(dat) } ctrl <- parsnip::fit_control(catch = TRUE) - mod <- - parsnip::fit.model_spec(spec, - .outcome ~ ., - data = rsample::analysis(split), - control = ctrl) + mod <- parsnip::fit.model_spec(spec, .outcome ~ ., data = dat, control = ctrl) mod } diff --git a/R/mars.R b/R/mars.R index 8c7784b..a54630c 100644 --- a/R/mars.R +++ b/R/mars.R @@ -1,9 +1,3 @@ -#' @importFrom earth earth evimp -#' @importFrom rsample analysis -#' @importFrom purrr map map2 map_df -#' @importFrom tibble tibble -#' @importFrom parsnip mars -#' @importFrom furrr future_map mars_bagger <- function(rs, opt, control, extract, ...) { @@ -14,7 +8,14 @@ mars_bagger <- function(rs, opt, control, extract, ...) { rs <- rs %>% - dplyr::mutate(model = iter(fit_seed, splits, seed_fit, .fn = mars_fit, spec = mod_spec)) + dplyr::mutate(model = iter( + fit_seed, + splits, + seed_fit, + .fn = mars_fit, + spec = mod_spec, + control = control + )) rs <- check_for_disaster(rs) @@ -54,13 +55,19 @@ make_mars_spec <- function(classif, opt) { mars_spec } -#' @importFrom stats complete.cases -mars_fit <- function(split, spec) { + +mars_fit <- function(split, spec, control = bag_control()) { ctrl <- parsnip::fit_control(catch = TRUE) + dat <- rsample::analysis(split) # only na.fail is supported by earth::earth dat <- dat[complete.cases(dat),, drop = FALSE] + + if (control$sampling == "down") { + dat <- down_sampler(dat) + } + mod <- parsnip::fit.model_spec(spec, .outcome ~ ., data = dat, control = ctrl) mod } diff --git a/R/misc.R b/R/misc.R index a3e71d3..99d9254 100644 --- a/R/misc.R +++ b/R/misc.R @@ -10,8 +10,6 @@ join_args <- function(default, others) { failed_stats <- tibble(.metric = "failed", .estiamtor = "none", .estimate = NA_real_) -#' @importFrom rsample assessment -#' @importFrom stats setNames sd predict oob_parsnip <- function(model, split, met) { dat <- rsample::assessment(split) y <- dat$.outcome @@ -146,11 +144,28 @@ filter_rs <- function(rs) { # ------------------------------------------------------------------------------ -#' @importFrom withr with_seed + seed_fit <- function(seed, split, .fn, ...) { withr::with_seed(seed, .fn(split, ...)) } +# ------------------------------------------------------------------------------ + +down_sampler <- function(x) { + + if (!is.factor(x$.outcome)) { + warning("Down-sampling is only used in classification models.", call. = FALSE) + return(x) + } + + min_n <- min(table(x$.outcome)) + x %>% + group_by(.outcome) %>% + sample_n(size = min_n, replace = TRUE) %>% + ungroup() +} + + # ------------------------------------------------------------------------------ diff --git a/R/validate.R b/R/validate.R index 46c6e72..e7853a1 100644 --- a/R/validate.R +++ b/R/validate.R @@ -66,14 +66,10 @@ validate_y_type <- function(model, outcomes) { } if (model == "model_rules") { - if (!is.numeric(outcomes[[1]])) - stop("Outcome data must be numeric for model rules.", call. = FALSE) - #hardhat::validate_outcomes_is_numeric(outcomes) + hardhat::validate_outcomes_are_binary(outcomes) } if (model == "C5.0") { - if (!is.factor(outcomes[[1]])) - stop("Outcome data must be a factor for C5.0.", call. = FALSE) - #hardhat::validate_outcomes_is_factor(outcomes) + hardhat::validate_outcomes_are_factors(outcomes) } } diff --git a/man/bagger.Rd b/man/bagger.Rd index 34d8734..513538e 100644 --- a/man/bagger.Rd +++ b/man/bagger.Rd @@ -14,7 +14,7 @@ bagger(x, ...) \method{bagger}{default}(x, ...) \method{bagger}{data.frame}(x, y, model = "CART", B = 10L, - opt = NULL, var_imp = FALSE, oob = NULL, extract = NULL, ...) + opt = NULL, control = bag_control(), extract = NULL, ...) \method{bagger}{matrix}(x, y, model = "CART", B = 10L, opt = NULL, control = bag_control(), extract = NULL, ...) @@ -43,6 +43,8 @@ samples/ensemble members (some model fits might fail).} \item{opt}{A named list (or NULL) of arguments to pass to the underlying model function. A list of possible arguments per model are given in Details.} +\item{control}{A list of options generated by \code{bag_control()}.} + \item{extract}{A function (or NULL) that can extract model-related aspects of each ensemble member. See Details and example below.}