improve error messages #80

Open
wants to merge 13 commits into
base: main
31 changes: 27 additions & 4 deletions DESCRIPTION
@@ -4,7 +4,7 @@ Version: 1.0.2.9000
Authors@R: c(
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre"),
comment = c(ORCID = "0000-0003-2402-136X")),
person(given = "Posit Software, PBC", role = c("cph", "fnd"))
person("Posit Software, PBC", role = c("cph", "fnd"))
)
Description: Tree- and rule-based models can be bagged
(<doi:10.1007/BF00058655>) using this package and their predictions
@@ -28,17 +28,16 @@ Imports:
hardhat (>= 1.1.0),
magrittr,
purrr,
rlang,
rlang (>= 1.1.0),
rpart,
rsample,
tibble,
tidyr,
utils,
withr
Suggests:
AmesHousing,
Review comment (Member): I'm going to assume that this is left over from a long time ago, and that we don't need it since we have modeldata.

covr,
earth,
earth,
modeldata,
nnet,
recipes,
@@ -48,7 +47,31 @@ Suggests:
yardstick
Config/Needs/website: tidyverse/tidytemplate
Config/testthat/edition: 3
Config/usethis/last-upkeep: 2024-10-23
Encoding: UTF-8
Language: en-US
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.2
Collate:
'C5.0.R'
'bag_mars_data.R'
'bag_nnet_data.R'
'bag_tree_data.R'
'import-standalone-types-check.R'
'validate.R'
'bagger.R'
'baguette-package.R'
'bridge.R'
'cart.R'
'class_cost.R'
'constructor.R'
'cost_models.R'
'import-standalone-obj-type.R'
'mars.R'
'misc.R'
'model_info.R'
'nnet.R'
'out-of-bag.R'
'predict.R'
'var_imp.R'
'zzz.R'
2 changes: 1 addition & 1 deletion R/C5.0.R
@@ -32,7 +32,7 @@ c5_bagger <- function(rs, control, ...) {
if (control$reduce) {
rs <-
rs %>%
mutate(model = map(model, axe_C5))
dplyr::mutate(model = purrr::map(model, axe_C5))
}

list(model = rs, imp = imps)
160 changes: 0 additions & 160 deletions R/aaa_validate.R

This file was deleted.

2 changes: 1 addition & 1 deletion R/bag_mars_data.R
@@ -3,7 +3,7 @@
# they are already in the parsnip model database. We'll exclude them from
# coverage stats for this reason.

# nocov
# nocov start

make_bag_mars <- function() {

2 changes: 1 addition & 1 deletion R/bag_nnet_data.R
@@ -3,7 +3,7 @@
# they are already in the parsnip model database. We'll exclude them from
# coverage stats for this reason.

# nocov
# nocov start

make_bag_mlp <- function() {

2 changes: 1 addition & 1 deletion R/bag_tree_data.R
@@ -3,7 +3,7 @@
# they are already in the parsnip model database. We'll exclude them from
# coverage stats for this reason.

# nocov
# nocov start

make_bag_tree <- function() {
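
Note on the `# nocov` to `# nocov start` change in the three `*_data.R` files: in covr, a bare `# nocov` comment excludes only the line it is on, while `# nocov start` opens a range that runs until a matching `# nocov end`. The sketch below shows the range form; it assumes a `# nocov end` exists further down in each file (not visible in these hunks), and `make_example_model()` and `add_one()` are hypothetical names used only for illustration.

```r
# nocov start
# Hypothetical registration helper. Everything between `# nocov start`
# and `# nocov end` is skipped by covr when it computes coverage.
make_example_model <- function() {
  "registered"
}
# nocov end

# This function sits outside the excluded range, so covr still measures it.
add_one <- function(x) {
  x + 1
}
```

The markers are ordinary comments, so they have no effect on how the code runs.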

3 changes: 2 additions & 1 deletion R/bagger.R
@@ -19,7 +19,7 @@
#' @param times A single integer greater than 1 for the maximum number of bootstrap
#' samples/ensemble members (some model fits might fail).
#' @param control A list of options generated by `control_bag()`.
#' @param cost A non-negative scale (for two class problems) or a cost matrix.
#' @param cost A non-negative scale (for two class problems) or a square cost matrix.
#' @param ... Optional arguments to pass to the base model function.
#' @details `bagger()` fits separate models to bootstrap samples. The
#' prediction function for each model object is encoded in an R expression and
@@ -89,6 +89,7 @@
#' cart_pca_bag
#' }
#' @export
#' @include validate.R
Review comment (Member Author): DOWN WITH aaa FILES

bagger <- function(x, ...) {
UseMethod("bagger")
}
34 changes: 18 additions & 16 deletions R/aaa.R → R/baguette-package.R
@@ -1,7 +1,10 @@
#' @keywords internal
"_PACKAGE"

## usethis namespace: start
#' @import rlang
#' @import dplyr
#' @import hardhat
#'
#' @importFrom parsnip set_engine fit fit_xy control_parsnip mars decision_tree
#' @importFrom parsnip set_new_model multi_predict update_dot_check show_fit
#' @importFrom parsnip new_model_spec null_value update_main_parameters
@@ -17,8 +20,20 @@
#' @importFrom withr with_seed
#' @importFrom dials new_quant_param
#' @importFrom stats coef
#'
# ------------------------------------------------------------------------------

#' @keywords internal
"_PACKAGE"

#' @importFrom magrittr %>%
#' @export
magrittr::`%>%`

#' @importFrom generics var_imp
#' @export
generics::var_imp

## usethis namespace: end
NULL

utils::globalVariables(
c(
@@ -42,16 +57,3 @@ utils::globalVariables(
".estimator"
)
)

# ------------------------------------------------------------------------------

# The functions below define the model information. These access the model
# environment inside of parsnip so they have to be executed once parsnip has
# been loaded.

.onLoad <- function(libname, pkgname) {
# This defines model functions in the parsnip model database
make_bag_tree()
make_bag_mars()
make_bag_mlp()
}
20 changes: 4 additions & 16 deletions R/bridge.R
@@ -1,4 +1,5 @@
bagger_bridge <- function(processed, weights, base_model, seed, times, control, cost, ...) {
bagger_bridge <- function(processed, weights, base_model, seed, times, control,
cost, ..., call = rlang::caller_env()) {
validate_outcomes_are_univariate(processed$outcomes)
if (base_model %in% c("C5.0")) {
validate_outcomes_are_factors(processed$outcomes)
@@ -27,8 +28,8 @@ bagger_bridge <- function(processed, weights, base_model, seed, times, control,
} else {
res <- switch(
base_model,
CART = cost_sens_cart_bagger(rs, control, cost, ...),
C5.0 = cost_sens_c5_bagger(rs, control, cost, ...)
CART = cost_sens_cart_bagger(rs, control, cost, ..., call = call),
C5.0 = cost_sens_c5_bagger(rs, control, cost, ..., call = call)
)
}

@@ -43,16 +44,3 @@ bagger_bridge <- function(processed, weights, base_model, seed, times, control,
)
res
}

validate_case_weights <- function(weights, data) {
if (is.null(weights)) {
return(invisible(NULL))
}
n <- nrow(data)
if (!is.vector(weights) || !is.numeric(weights) || length(weights) != n ||
any(weights < 0)) {
cli::cli_abort("'weights' should be a non-negative numeric vector with the same size as the data.")
}
invisible(NULL)
}
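
Note on the `call` argument: this PR threads `call = rlang::caller_env()` through `bagger_bridge()` and the cost-sensitive baggers so that `cli::cli_abort()` can attribute an error to the calling function instead of the internal helper. The replacement for the `validate_case_weights()` removed above is not shown in this diff, so the version below is only a sketch of how the same pattern might apply to it; the added `call` parameter and the `fit_wrapper()` caller are assumptions made for illustration.

```r
# Sketch only: the actual contents of R/validate.R are not shown in this diff.
validate_case_weights <- function(weights, data, call = rlang::caller_env()) {
  if (is.null(weights)) {
    return(invisible(NULL))
  }
  n <- nrow(data)
  if (!is.vector(weights) || !is.numeric(weights) || length(weights) != n ||
      any(weights < 0)) {
    # `call` tells cli/rlang which function to name in the error header.
    cli::cli_abort(
      "{.arg weights} should be a non-negative numeric vector with the same size as the data.",
      call = call
    )
  }
  invisible(NULL)
}

# Hypothetical user-facing function. Because the helper defaults `call` to
# its caller's environment, the error is reported against fit_wrapper().
fit_wrapper <- function(data, weights = NULL) {
  validate_case_weights(weights, data)
  invisible(TRUE)
}
```

Calling `fit_wrapper(mtcars, weights = -1)` should then raise an error reported against `fit_wrapper()` rather than `validate_case_weights()`.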

2 changes: 1 addition & 1 deletion R/cart.R
@@ -34,7 +34,7 @@ cart_bagger <- function(rs, control, ...) {
if (control$reduce) {
rs <-
rs %>%
mutate(model = map(model, axe_cart))
dplyr::mutate(model = purrr::map(model, axe_cart))
}

list(model = rs, imp = imps)
6 changes: 4 additions & 2 deletions R/constructor.R
@@ -1,9 +1,11 @@
new_bagger <- function(model_df, imp, control, cost, base_model, blueprint) {
new_bagger <- function(model_df, imp, control, cost, base_model, blueprint,
call = rlang::caller_env()) {

if (!is_tibble(model_df)) {
cli::cli_abort("`model_df` should be a tibble.")
cli::cli_abort("{.arg model_df} should be {.cls tibble}.", call = call)
}

# TODO extend to use mode from model object(s)
if (is.numeric(blueprint$ptypes$outcomes[[1]])) {
Review comment (Member Author): We could eventually extend baguette to censored regression models. There is a lot of code that inspects the data to determine the mode, or that assumes the mode is either regression or classification.

mod_mode <- "regression"
} else {
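
Note on the mode-detection comment in `new_bagger()`: one way to avoid assuming only two modes is to dispatch on the outcome prototype explicitly and fail on anything unrecognized. The sketch below is not baguette code. `infer_mode()` is a hypothetical helper, and checking for the `Surv` class (as produced by the survival package) is an assumption about how censored outcomes would be represented.

```r
# Hypothetical helper, not part of baguette.
infer_mode <- function(outcome) {
  if (inherits(outcome, "Surv")) {
    # Checked first: Surv objects are stored as numeric matrices.
    "censored regression"
  } else if (is.numeric(outcome)) {
    "regression"
  } else if (is.factor(outcome)) {
    "classification"
  } else {
    stop("Unsupported outcome type: ", class(outcome)[1], call. = FALSE)
  }
}

infer_mode(numeric(0))
infer_mode(factor(character(0)))
```

An explicit final branch means a new outcome type produces an error instead of silently falling into the classification path.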