Skip to content

Commit

Permalink
changes for #3
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo committed Jul 16, 2019
1 parent 091ce4e commit a82446b
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 39 deletions.
2 changes: 1 addition & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ importFrom(earth,earth)
importFrom(earth,evimp)
importFrom(furrr,future_map)
importFrom(furrr,future_map2)
importFrom(hardhat,validate_outcomes_is_univariate)
importFrom(magrittr,"%>%")
importFrom(parsnip,decision_tree)
importFrom(parsnip,fit)
Expand All @@ -47,5 +46,6 @@ importFrom(stats,setNames)
importFrom(tibble,as_tibble)
importFrom(tibble,is_tibble)
importFrom(tibble,tibble)
importFrom(tidypredict,tidypredict_fit)
importFrom(utils,globalVariables)
importFrom(withr,with_seed)
3 changes: 2 additions & 1 deletion R/bagger.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#' model function. A list of possible arguments per model are given in Details.
#' @param extract A function (or NULL) that can extract model-related aspects
#' of each ensemble member. See Details and example below.
#' @param control A list of options generated by `bag_control()`.
#' @param ... Optional arguments to pass to the `extract` function.
#' @details `bagger()` fits separate models to bootstrap samples. The
#' prediction function for each model object is encoded in an R expression and
Expand Down Expand Up @@ -84,7 +85,7 @@ bagger.data.frame <-
y,
model = "CART",
B = 10L,
opt = NULLL,
opt = NULL,
control = bag_control(),
extract = NULL,
...) {
Expand Down
3 changes: 1 addition & 2 deletions R/bridge.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#' @importFrom hardhat validate_outcomes_is_univariate
#' @importFrom rsample bootstraps


bagger_bridge <- function(processed, model, seed, B, opt, control, extract, ...) {
validate_outcomes_is_univariate(processed$outcomes)
Expand Down
27 changes: 11 additions & 16 deletions R/cart.R
Original file line number Diff line number Diff line change
@@ -1,10 +1,3 @@
#' @importFrom rpart rpart
#' @importFrom rsample analysis
#' @importFrom purrr map map2 map_df
#' @importFrom tibble tibble
#' @importFrom parsnip decision_tree
#' @importFrom furrr future_map
#' @importFrom partykit as.party.rpart

cart_bagger <- function(rs, opt, control, extract, ...) {
is_classif <- is.factor(rs$splits[[1]]$data$.outcome)
Expand All @@ -14,7 +7,14 @@ cart_bagger <- function(rs, opt, control, extract, ...) {

rs <-
rs %>%
dplyr::mutate(model = iter(fit_seed, splits, seed_fit, .fn = cart_fit, spec = mod_spec))
dplyr::mutate(model = iter(
fit_seed,
splits,
seed_fit,
.fn = cart_fit,
spec = mod_spec,
control = control
))

rs <- check_for_disaster(rs)

Expand Down Expand Up @@ -86,22 +86,17 @@ make_cart_spec <- function(classif, opt) {
cart_spec
}

#' @importFrom stats complete.cases

cart_fit <- function(split, spec, sampling = "none") {
cart_fit <- function(split, spec, control = bag_control()) {

dat <- rsample::analysis(split)

if (sampling == "down") {
if (control$sampling == "down") {
dat <- down_sampler(dat)
}

ctrl <- parsnip::fit_control(catch = TRUE)
mod <-
parsnip::fit.model_spec(spec,
.outcome ~ .,
data = rsample::analysis(split),
control = ctrl)
mod <- parsnip::fit.model_spec(spec, .outcome ~ ., data = dat, control = ctrl)
mod
}

Expand Down
25 changes: 16 additions & 9 deletions R/mars.R
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
#' @importFrom earth earth evimp
#' @importFrom rsample analysis
#' @importFrom purrr map map2 map_df
#' @importFrom tibble tibble
#' @importFrom parsnip mars
#' @importFrom furrr future_map

mars_bagger <- function(rs, opt, control, extract, ...) {

Expand All @@ -14,7 +8,14 @@ mars_bagger <- function(rs, opt, control, extract, ...) {

rs <-
rs %>%
dplyr::mutate(model = iter(fit_seed, splits, seed_fit, .fn = mars_fit, spec = mod_spec))
dplyr::mutate(model = iter(
fit_seed,
splits,
seed_fit,
.fn = mars_fit,
spec = mod_spec,
control = control
))

rs <- check_for_disaster(rs)

Expand Down Expand Up @@ -54,13 +55,19 @@ make_mars_spec <- function(classif, opt) {
mars_spec
}

#' @importFrom stats complete.cases

mars_fit <- function(split, spec) {

mars_fit <- function(split, spec, control = bag_control()) {
ctrl <- parsnip::fit_control(catch = TRUE)

dat <- rsample::analysis(split)
# only na.fail is supported by earth::earth
dat <- dat[complete.cases(dat),, drop = FALSE]

if (control$sampling == "down") {
dat <- down_sampler(dat)
}

mod <- parsnip::fit.model_spec(spec, .outcome ~ ., data = dat, control = ctrl)
mod
}
Expand Down
21 changes: 18 additions & 3 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ join_args <- function(default, others) {

failed_stats <- tibble(.metric = "failed", .estiamtor = "none", .estimate = NA_real_)

#' @importFrom rsample assessment
#' @importFrom stats setNames sd predict
oob_parsnip <- function(model, split, met) {
dat <- rsample::assessment(split)
y <- dat$.outcome
Expand Down Expand Up @@ -146,11 +144,28 @@ filter_rs <- function(rs) {

# ------------------------------------------------------------------------------

#' @importFrom withr with_seed

seed_fit <- function(seed, split, .fn, ...) {
withr::with_seed(seed, .fn(split, ...))
}

# ------------------------------------------------------------------------------

down_sampler <- function(x) {

if (!is.factor(x$.outcome)) {
warning("Down-sampling is only used in classification models.", call. = FALSE)
return(x)
}

min_n <- min(table(x$.outcome))
x %>%
group_by(.outcome) %>%
sample_n(size = min_n, replace = TRUE) %>%
ungroup()
}



# ------------------------------------------------------------------------------

Expand Down
8 changes: 2 additions & 6 deletions R/validate.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,10 @@ validate_y_type <- function(model, outcomes) {
}

if (model == "model_rules") {
if (!is.numeric(outcomes[[1]]))
stop("Outcome data must be numeric for model rules.", call. = FALSE)
#hardhat::validate_outcomes_is_numeric(outcomes)
hardhat::validate_outcomes_are_binary(outcomes)
}
if (model == "C5.0") {
if (!is.factor(outcomes[[1]]))
stop("Outcome data must be a factor for C5.0.", call. = FALSE)
#hardhat::validate_outcomes_is_factor(outcomes)
hardhat::validate_outcomes_are_factors(outcomes)
}

}
Expand Down
4 changes: 3 additions & 1 deletion man/bagger.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit a82446b

Please sign in to comment.