diff --git a/NAMESPACE b/NAMESPACE index 64a8f5365..0782892d6 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -202,7 +202,6 @@ export(bag_mars) export(bag_mlp) export(bag_tree) export(bart) -export(bartMachine_interval_calc) export(boost_tree) export(case_weights_allowed) export(cforest_train) diff --git a/R/arguments.R b/R/arguments.R index 156916391..3142ce35a 100644 --- a/R/arguments.R +++ b/R/arguments.R @@ -258,7 +258,7 @@ make_form_call <- function(object, env = NULL) { } # TODO we need something to indicate that case weights are being used. -make_xy_call <- function(object, target, env) { +make_xy_call <- function(object, target, env, call = rlang::caller_env()) { fit_args <- object$method$fit$args uses_weights <- has_weights(env) @@ -283,7 +283,7 @@ make_xy_call <- function(object, target, env) { data.frame = rlang::expr(maybe_data_frame(x)), matrix = rlang::expr(maybe_matrix(x)), dgCMatrix = rlang::expr(maybe_sparse_matrix(x)), - cli::cli_abort("Invalid data type target: {target}.") + cli::cli_abort("Invalid data type target: {target}.", call = call) ) if (uses_weights) { object$method$fit$args[[ unname(data_args["weights"]) ]] <- rlang::expr(weights) diff --git a/R/autoplot.R b/R/autoplot.R index 5ace8f2e0..a880ed4c5 100644 --- a/R/autoplot.R +++ b/R/autoplot.R @@ -41,14 +41,15 @@ autoplot.glmnet <- function(object, ..., min_penalty = 0, best_penalty = NULL, } -map_glmnet_coefs <- function(x) { +map_glmnet_coefs <- function(x, call = rlang::caller_env()) { coefs <- coef(x) # If parsnip is used to fit the model, glmnet should be attached and this will # work. If an object is loaded from a new session, they will need to load the # package. if (is.null(coefs)) { cli::cli_abort( - "Please load the {.pkg glmnet} package before running {.fun autoplot}." + "Please load the {.pkg glmnet} package before running {.fun autoplot}.", + call = call ) } p <- x$dim[1] @@ -89,9 +90,10 @@ top_coefs <- function(x, top_n = 5) { dplyr::slice(seq_len(top_n)) } -autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L, ...) { +autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L, + call = rlang::caller_env(), ...) { tidy_coefs <- - map_glmnet_coefs(x) %>% + map_glmnet_coefs(x, call = call) %>% dplyr::filter(penalty >= min_penalty) actual_min_penalty <- min(tidy_coefs$penalty) diff --git a/R/bart.R b/R/bart.R index 0070f702e..f4e54657c 100644 --- a/R/bart.R +++ b/R/bart.R @@ -130,61 +130,13 @@ update.bart <- ) } - #' Developer functions for predictions via BART models -#' @export -#' @keywords internal #' @name bart-internal #' @inheritParams predict.model_fit #' @param obj A parsnip object. -#' @param ci Confidence (TRUE) or prediction interval (FALSE) #' @param level Confidence level. #' @param std_err Attach column for standard error of prediction or not. -bartMachine_interval_calc <- function(new_data, obj, ci = TRUE, level = 0.95) { - if (obj$spec$mode == "classification") { - cli::cli_abort( - "Prediction intervals are not possible for classification" - ) - } - get_std_err <- obj$spec$method$pred$pred_int$extras$std_error - - if (ci) { - cl <- - rlang::call2( - "calc_credible_intervals", - .ns = "bartMachine", - bart_machine = rlang::expr(obj$fit), - new_data = rlang::expr(new_data), - ci_conf = level - ) - - } else { - cl <- - rlang::call2( - "calc_prediction_intervals", - .ns = "bartMachine", - bart_machine = rlang::expr(obj$fit), - new_data = rlang::expr(new_data), - pi_conf = level - ) - } - res <- rlang::eval_tidy(cl) - if (!ci) { - if (get_std_err) { - .std_error <- apply(res$all_prediction_samples, 1, stats::sd, na.rm = TRUE) - } - res <- res$interval - } - res <- tibble::as_tibble(res) - names(res) <- c(".pred_lower", ".pred_upper") - if (!ci & get_std_err) { - res$.std_err <- .std_error - } - res -} - #' @export -#' @rdname bart-internal #' @keywords internal dbart_predict_calc <- function(obj, new_data, type, level = 0.95, std_err = FALSE) { types <- c("numeric", "class", "prob", "conf_int", "pred_int") diff --git a/R/condense_control.R b/R/condense_control.R index a5b49bd97..d7b2f984f 100644 --- a/R/condense_control.R +++ b/R/condense_control.R @@ -10,6 +10,10 @@ #' #' @return A control object with the same elements and classes of `ref`, with #' values of `x`. +#' @param call The execution environment of a currently running function, e.g. +#' `caller_env()`. The function will be mentioned in error messages as the +#' source of the error. See the call argument of [rlang::abort()] for more +#' information. #' @keywords internal #' @export #' @@ -20,16 +24,17 @@ #' #' ctrl <- condense_control(ctrl, control_parsnip()) #' str(ctrl) -condense_control <- function(x, ref) { +condense_control <- function(x, ref, ..., call = rlang::caller_env()) { + check_dots_empty() mismatch <- setdiff(names(ref), names(x)) if (length(mismatch)) { cli::cli_abort( c( - "Object of class {.cls class(x)[1]} cannot be coerced to - object of class {.cls class(ref)[1]}.", + "{.obj_type_friendly {x}} cannot be coerced to {.obj_type_friendly {ref}}.", "i" = "{cli::qty(mismatch)} The argument{?s} {.arg {mismatch}} {?is/are} missing." - ) + ), + call = call ) } res <- x[names(ref)] diff --git a/R/contr_one_hot.R b/R/contr_one_hot.R index 7ea1115ae..00cdb3484 100644 --- a/R/contr_one_hot.R +++ b/R/contr_one_hot.R @@ -3,7 +3,8 @@ #' This contrast function produces a model matrix with indicator columns for #' each level of each factor. #' -#' @param n A vector of character factor levels or the number of unique levels. +#' @param n A vector of character factor levels (of length >=1) or the number +#' of unique levels (>= 1). #' @param contrasts This argument is for backwards compatibility and only the #' default of `TRUE` is supported. #' @param sparse This argument is for backwards compatibility and only the @@ -24,9 +25,13 @@ contr_one_hot <- function(n, contrasts = TRUE, sparse = FALSE) { } if (is.character(n)) { + if (length(n) < 1) { + cli::cli_abort("{.arg n} cannot be empty.") + } names <- n n <- length(names) } else if (is.numeric(n)) { + check_number_whole(n, min = 1) n <- as.integer(n) if (length(n) != 1L) { @@ -35,7 +40,7 @@ contr_one_hot <- function(n, contrasts = TRUE, sparse = FALSE) { names <- as.character(seq_len(n)) } else { - cli::cli_abort("{.arg n} must be a character vector or an integer of size 1.") + check_number_whole(n, min = 1) } out <- diag(n) diff --git a/R/convert_data.R b/R/convert_data.R index 6cf1e8d47..98615cab0 100644 --- a/R/convert_data.R +++ b/R/convert_data.R @@ -40,18 +40,21 @@ na.action = na.omit, indicators = "traditional", composition = "data.frame", - remove_intercept = TRUE) { + remove_intercept = TRUE, + call = rlang::caller_env()) { if (!(composition %in% c("data.frame", "matrix", "dgCMatrix"))) { cli::cli_abort( "{.arg composition} should be either {.val data.frame}, {.val matrix}, or - {.val dgCMatrix}." + {.val dgCMatrix}.", + call = call ) } if (sparsevctrs::has_sparse_elements(data)) { cli::cli_abort( - "Sparse data cannot be used with formula interface. Please use - {.fn fit_xy} instead." + "Sparse data cannot be used with formula interface. Please use + {.fn fit_xy} instead.", + call = call ) } @@ -84,7 +87,7 @@ w <- as.vector(model.weights(mod_frame)) if (!is.null(w) && !is.numeric(w)) { - cli::cli_abort("{.arg weights} must be a numeric vector.") + cli::cli_abort("{.arg weights} must be a numeric vector.", call = call) } # TODO: Do we actually use the offset when fitting? @@ -175,10 +178,12 @@ .convert_form_to_xy_new <- function(object, new_data, na.action = na.pass, - composition = "data.frame") { + composition = "data.frame", + call = rlang::caller_env()) { if (!(composition %in% c("data.frame", "matrix"))) { cli::cli_abort( - "{.arg composition} should be either {.val data.frame} or {.val matrix}." + "{.arg composition} should be either {.val data.frame} or {.val matrix}.", + call = call ) } @@ -244,9 +249,10 @@ y, weights = NULL, y_name = "..y", - remove_intercept = TRUE) { + remove_intercept = TRUE, + call = rlang::caller_env()) { if (is.vector(x)) { - cli::cli_abort("{.arg x} cannot be a vector.") + cli::cli_abort("{.arg x} cannot be a vector.", call = call) } if (remove_intercept) { @@ -279,10 +285,10 @@ if (!is.null(weights)) { if (!is.numeric(weights)) { - cli::cli_abort("{.arg weights} must be a numeric vector.") + cli::cli_abort("{.arg weights} must be a numeric vector.", call = call) } if (length(weights) != nrow(x)) { - cli::cli_abort("{.arg weights} should have {nrow(x)} elements.") + cli::cli_abort("{.arg weights} should have {nrow(x)} elements.", call = call) } form <- patch_formula_environment_with_case_weights( diff --git a/R/descriptors.R b/R/descriptors.R index 0d94d14a9..34640d949 100644 --- a/R/descriptors.R +++ b/R/descriptors.R @@ -103,22 +103,23 @@ NULL # Descriptor retrievers -------------------------------------------------------- -get_descr_form <- function(formula, data) { +get_descr_form <- function(formula, data, call = rlang::caller_env()) { if (inherits(data, "tbl_spark")) { res <- get_descr_spark(formula, data) } else { - res <- get_descr_df(formula, data) + res <- get_descr_df(formula, data, call = call) } res } -get_descr_df <- function(formula, data) { +get_descr_df <- function(formula, data, call = rlang::caller_env()) { tmp_dat <- .convert_form_to_xy_fit(formula, data, indicators = "none", - remove_intercept = TRUE) + remove_intercept = TRUE, + call = call) if(is.factor(tmp_dat$y)) { .lvls <- function() { @@ -136,7 +137,8 @@ get_descr_df <- function(formula, data) { formula, data, indicators = "traditional", - remove_intercept = TRUE + remove_intercept = TRUE, + call = call )$x ) } @@ -263,7 +265,7 @@ get_descr_spark <- function(formula, data) { ) } -get_descr_xy <- function(x, y) { +get_descr_xy <- function(x, y, call = rlang::caller_env()) { .lvls <- if (is.factor(y)) { function() table(y, dnn = NULL) @@ -291,7 +293,7 @@ get_descr_xy <- function(x, y) { } .dat <- function() { - .convert_xy_to_form_fit(x, y, remove_intercept = TRUE)$data + .convert_xy_to_form_fit(x, y, remove_intercept = TRUE, call = call)$data } .x <- function() { diff --git a/R/fit.R b/R/fit.R index 9925436da..5f7416aa9 100644 --- a/R/fit.R +++ b/R/fit.R @@ -157,7 +157,7 @@ fit.model_spec <- } if (all(c("x", "y") %in% names(dots))) { - cli::cli_abort("`fit.model_spec()` is for the formula methods. Use `fit_xy()` instead.") + cli::cli_abort("{.fn fit.model_spec} is for the formula methods. Use {.fn fit_xy} instead.") } cl <- match.call(expand.dots = TRUE) # Create an environment with the evaluated argument objects. This will be @@ -307,7 +307,8 @@ fit_xy.model_spec <- if (object$engine == "spark") { cli::cli_abort( - "spark objects can only be used with the formula interface to {.fn fit} with a spark data object." + "spark objects can only be used with the formula interface to {.fn fit} + with a spark data object." ) } diff --git a/R/fit_helpers.R b/R/fit_helpers.R index ca342b31f..168ea8e44 100644 --- a/R/fit_helpers.R +++ b/R/fit_helpers.R @@ -27,7 +27,7 @@ form_form <- # if descriptors are needed, update descr_env with the calculated values if (requires_descrs(object)) { - data_stats <- get_descr_form(env$formula, env$data) + data_stats <- get_descr_form(env$formula, env$data, call = call) scoped_descrs(data_stats) } @@ -86,7 +86,7 @@ xy_xy <- function(object, # if descriptors are needed, update descr_env with the calculated values if (requires_descrs(object)) { - data_stats <- get_descr_xy(env$x, env$y) + data_stats <- get_descr_xy(env$x, env$y, call = call) scoped_descrs(data_stats) } @@ -96,7 +96,7 @@ xy_xy <- function(object, # sub in arguments to actual syntax for corresponding engine object <- translate(object, engine = object$engine) - fit_call <- make_xy_call(object, target, env) + fit_call <- make_xy_call(object, target, env, call) res <- list(lvl = levels(env$y), spec = object) @@ -141,7 +141,8 @@ form_xy <- function(object, control, env, ..., composition = target, indicators = indicators, - remove_intercept = remove_intercept + remove_intercept = remove_intercept, + call = call ) env$x <- data_obj$x env$y <- data_obj$y diff --git a/R/glmnet-engines.R b/R/glmnet-engines.R index 296468e49..2d31bee3c 100644 --- a/R/glmnet-engines.R +++ b/R/glmnet-engines.R @@ -357,18 +357,21 @@ format_glmnet_multinom_class <- function(pred, penalty, lvl, n_obs) { #' @rdname glmnet_helpers #' @keywords internal #' @export -.check_glmnet_penalty_fit <- function(x) { +.check_glmnet_penalty_fit <- function(x, call = rlang::caller_env()) { pen <- rlang::eval_tidy(x$args$penalty) if (length(pen) != 1) { - cli::cli_abort(c( - "x" = "For the glmnet engine, {.arg penalty} must be a single number - (or a value of {.fn tune}).", - "!" = "There are {length(pen)} value{?s} for {.arg penalty}.", - "i" = "To try multiple values for total regularization, use the - {.pkg tune} package.", - "i" = "To predict multiple penalties, use {.fn multi_predict}." - )) + cli::cli_abort( + c( + "x" = "For the glmnet engine, {.arg penalty} must be a single number + (or a value of {.fn tune}).", + "!" = "There are {length(pen)} value{?s} for {.arg penalty}.", + "i" = "To try multiple values for total regularization, use the + {.pkg tune} package.", + "i" = "To predict multiple penalties, use {.fn multi_predict}." + ), + call = call + ) } } @@ -379,7 +382,8 @@ format_glmnet_multinom_class <- function(pred, penalty, lvl, n_obs) { #' @rdname glmnet_helpers #' @keywords internal #' @export -.check_glmnet_penalty_predict <- function(penalty = NULL, object, multi = FALSE) { +.check_glmnet_penalty_predict <- function(penalty = NULL, object, multi = FALSE, + call = rlang::caller_env()) { if (is.null(penalty)) { penalty <- object$fit$lambda } @@ -387,19 +391,25 @@ format_glmnet_multinom_class <- function(pred, penalty, lvl, n_obs) { # when using `predict()`, allow for a single lambda if (!multi) { if (length(penalty) != 1) { - cli::cli_abort(c( - "{.arg penalty} should be a single numeric value.", - "i" = "{.fn multi_predict} can be used to get multiple predictions per row of data." - )) + cli::cli_abort( + c( + "{.arg penalty} should be a single numeric value.", + "i" = "{.fn multi_predict} can be used to get multiple predictions per row of data." + ), + call = call + ) } } if (length(object$fit$lambda) == 1 && penalty != object$fit$lambda) { - cli::cli_abort(c( - "The glmnet model was fit with a single penalty value of + cli::cli_abort( + c( + "The glmnet model was fit with a single penalty value of {.arg object$fit$lambda}. Predicting with a value of {.arg penalty} will give incorrect results from `glmnet()`." - )) + ), + call = call + ) } penalty diff --git a/R/mars.R b/R/mars.R index 49321b3f7..6decfb8ce 100644 --- a/R/mars.R +++ b/R/mars.R @@ -161,21 +161,23 @@ multi_predict._earth <- object$fit$call[[i]] <- eval_tidy(object$fit$call[[i]]) } - msg <- - c("x" = "Please use {.code keepxy = TRUE} as an option to enable submodel + msg <- + c("x" = "Please use {.code keepxy = TRUE} as an option to enable submodel predictions with earth.") if (any(names(object$fit$call) == "keepxy")) { - if (!isTRUE(object$fit$call$keepxy)) - cli::cli_abort(msg) + if (!isTRUE(object$fit$call$keepxy)) { + cli::cli_abort(msg) + } } else { cli::cli_abort(msg) } if (is.null(type)) { - if (object$spec$mode == "classification") + if (object$spec$mode == "classification") { type <- "class" - else + } else { type <- "numeric" + } } res <- diff --git a/R/misc.R b/R/misc.R index dfec898a6..7eb22b4f9 100644 --- a/R/misc.R +++ b/R/misc.R @@ -415,21 +415,21 @@ check_outcome <- function(y, spec) { #' @export #' @keywords internal #' @rdname add_on_exports -check_final_param <- function(x) { +check_final_param <- function(x, call = rlang::caller_env()) { if (is.null(x)) { return(invisible(x)) } if (!is.list(x) & !tibble::is_tibble(x)) { - cli::cli_abort("The parameter object should be a list or tibble.") + cli::cli_abort("The parameter object should be a list or tibble.", call = call) } if (tibble::is_tibble(x) && nrow(x) > 1) { - cli::cli_abort("The parameter tibble should have a single row.") + cli::cli_abort("The parameter tibble should have a single row.", call = call) } if (tibble::is_tibble(x)) { x <- as.list(x) } if (length(names) == 0 || any(names(x) == "")) { - cli::cli_abort("All values in {.arg parameters} should have a name.") + cli::cli_abort("All values in {.arg parameters} should have a name.", call = call) } invisible(x) @@ -438,7 +438,7 @@ check_final_param <- function(x) { #' @export #' @keywords internal #' @rdname add_on_exports -update_main_parameters <- function(args, param) { +update_main_parameters <- function(args, param, call = rlang::caller_env()) { if (length(param) == 0) { return(args) } @@ -451,7 +451,8 @@ update_main_parameters <- function(args, param) { extra_args <- names(param)[has_extra_args] if (any(has_extra_args)) { cli::cli_abort( - "Argument{?s} {.arg {extra_args}} {?is/are} not a main argument." + "Argument{?s} {.arg {extra_args}} {?is/are} not a main argument.", + call = call ) } param <- param[!has_extra_args] diff --git a/R/mlp.R b/R/mlp.R index b35dbe32c..5080c555e 100644 --- a/R/mlp.R +++ b/R/mlp.R @@ -146,9 +146,10 @@ check_args.mlp <- function(object, call = rlang::caller_env()) { # keras wrapper for feed-forward nnet -class2ind <- function (x, drop2nd = FALSE) { - if (!is.factor(x)) - cli::cli_abort(c("x" = "{.arg x} should be a factor.")) +class2ind <- function (x, drop2nd = FALSE, call = rlang::caller_env()) { + if (!is.factor(x)) { + cli::cli_abort(c("x" = "{.arg x} should be a {cls factor} not {.obj_type_friendly {x}.")) + } y <- model.matrix( ~ x - 1) colnames(y) <- gsub("^x", "", colnames(y)) attributes(y)$assign <- NULL diff --git a/R/predict_class.R b/R/predict_class.R index 3c8fe69a6..d11f79f36 100644 --- a/R/predict_class.R +++ b/R/predict_class.R @@ -10,13 +10,15 @@ #' @export predict_class.model_fit <- function(object, new_data, ...) { if (object$spec$mode != "classification") { - cli::cli_abort("{.fun predict.model_fit} is for predicting factor outcomes.") + cli::cli_abort("{.fun predict.model_fit} is for predicting factor outcomes.", + call = rlang::call2("predict")) } check_spec_pred_type(object, "class") if (inherits(object$fit, "try-error")) { - cli::cli_warn("Model fit failed; cannot make predictions.") + cli::cli_warn("Model fit failed; cannot make predictions.", + call = rlang::call2("predict")) return(NULL) } diff --git a/R/predict_classprob.R b/R/predict_classprob.R index 86bea49d3..7642ae3d3 100644 --- a/R/predict_classprob.R +++ b/R/predict_classprob.R @@ -6,7 +6,8 @@ #' @export predict_classprob.model_fit <- function(object, new_data, ...) { if (object$spec$mode != "classification") { - cli::cli_abort("{.fun predict.model_fit()} is for predicting factor outcomes.") + cli::cli_abort("{.fun predict.model_fit()} is for predicting factor outcomes.", + call = rlang::call2("predict")) } check_spec_pred_type(object, "prob", call = caller_env()) @@ -36,7 +37,8 @@ predict_classprob.model_fit <- function(object, new_data, ...) { # check and sort names if (!is.data.frame(res) & !inherits(res, "tbl_spark")) { - cli::cli_abort("The was a problem with the probability predictions.") + cli::cli_abort("The was a problem with the probability predictions.", + call = rlang::call2("predict")) } if (!is_tibble(res) & !inherits(res, "tbl_spark")) { diff --git a/R/predict_numeric.R b/R/predict_numeric.R index 6f8aed916..509c39ef0 100644 --- a/R/predict_numeric.R +++ b/R/predict_numeric.R @@ -11,7 +11,8 @@ predict_numeric.model_fit <- function(object, new_data, ...) { "{.fun predict_numeric} is for predicting numeric outcomes.", "i" = "Use {.fun predict_class} or {.fun predict_classprob} for classification models." - ) + ), + call = rlang::call2("predict") ) } diff --git a/R/predict_quantile.R b/R/predict_quantile.R index 6a8b5060b..54d6461ff 100644 --- a/R/predict_quantile.R +++ b/R/predict_quantile.R @@ -37,7 +37,8 @@ predict_quantile.model_fit <- function(object, if (object$spec$mode == "quantile regression") { if (!is.null(quantile_levels)) { cli::cli_abort("When the mode is {.val quantile regression}, - {.arg quantile_levels} are specified by {.fn set_mode}.") + {.arg quantile_levels} are specified by {.fn set_mode}.", + call = rlang::call2("predict")) } } else { if (is.null(quantile_levels)) { diff --git a/R/predict_time.R b/R/predict_time.R index 769b7a578..616cb29aa 100644 --- a/R/predict_time.R +++ b/R/predict_time.R @@ -11,7 +11,8 @@ predict_time.model_fit <- function(object, new_data, ...) { "{.fun predict_time} is for predicting time outcomes.", "i" = "Use {.fun predict_class} or {.fun predict_classprob} for classification models." - ) + ), + call = rlang::call2("predict") ) } diff --git a/R/rand_forest_data.R b/R/rand_forest_data.R index ef548dda3..598bceaa9 100644 --- a/R/rand_forest_data.R +++ b/R/rand_forest_data.R @@ -66,7 +66,8 @@ ranger_confint <- function(object, new_data, ...) { } else { cli::cli_abort( "Cannot compute confidence intervals for a ranger forest - of type {.val {object$fit$forest$treetype}}." + of type {.val {object$fit$forest$treetype}}.", + call = rlang::call2("predict") ) } } diff --git a/R/translate.R b/R/translate.R index eb2ef7158..6c0fc2881 100644 --- a/R/translate.R +++ b/R/translate.R @@ -176,6 +176,7 @@ add_methods <- function(x, engine) { #' or in a list that can facilitate renaming grid objects? #' @return A tibble with columns `user`, `parsnip`, and `engine`, or a list #' with named character vectors `user_to_parsnip` and `parsnip_to_engine`. +#' @keywords internal #' @examplesIf !parsnip:::is_cran_check() #' mod <- #' linear_reg(penalty = tune("regularization"), mixture = tune()) %>% diff --git a/R/update.R b/R/update.R index 3dd890fb4..02cf1d53f 100644 --- a/R/update.R +++ b/R/update.R @@ -64,10 +64,10 @@ update_spec <- function(object, parameters, args_enquo_list, fresh, cls, ..., eng_args <- update_engine_parameters(object$eng_args, fresh, ...) if (!is.null(parameters)) { - parameters <- check_final_param(parameters) + parameters <- check_final_param(parameters, call = call) } - args <- update_main_parameters(args_enquo_list, parameters) + args <- update_main_parameters(args_enquo_list, parameters, call = call) if (fresh) { object$args <- args diff --git a/man/add_on_exports.Rd b/man/add_on_exports.Rd index 2f1d62a14..09c3434fb 100644 --- a/man/add_on_exports.Rd +++ b/man/add_on_exports.Rd @@ -34,9 +34,9 @@ new_model_spec( user_specified_engine = TRUE ) -check_final_param(x) +check_final_param(x, call = rlang::caller_env()) -update_main_parameters(args, param) +update_main_parameters(args, param, call = rlang::caller_env()) update_engine_parameters(eng_args, fresh, ...) diff --git a/man/bart-internal.Rd b/man/bart-internal.Rd index 2665bd380..4056cd69f 100644 --- a/man/bart-internal.Rd +++ b/man/bart-internal.Rd @@ -2,28 +2,23 @@ % Please edit documentation in R/bart.R \name{bart-internal} \alias{bart-internal} -\alias{bartMachine_interval_calc} \alias{dbart_predict_calc} \title{Developer functions for predictions via BART models} \usage{ -bartMachine_interval_calc(new_data, obj, ci = TRUE, level = 0.95) - dbart_predict_calc(obj, new_data, type, level = 0.95, std_err = FALSE) } \arguments{ -\item{new_data}{A rectangular data object, such as a data frame.} - \item{obj}{A parsnip object.} -\item{ci}{Confidence (TRUE) or prediction interval (FALSE)} - -\item{level}{Confidence level.} +\item{new_data}{A rectangular data object, such as a data frame.} \item{type}{A single character value or \code{NULL}. Possible values are \code{"numeric"}, \code{"class"}, \code{"prob"}, \code{"conf_int"}, \code{"pred_int"}, \code{"quantile"}, \code{"time"}, \code{"hazard"}, \code{"survival"}, or \code{"raw"}. When \code{NULL}, \code{predict()} will choose an appropriate value based on the model's mode.} +\item{level}{Confidence level.} + \item{std_err}{Attach column for standard error of prediction or not.} } \description{ diff --git a/man/condense_control.Rd b/man/condense_control.Rd index 326dae09b..d347bcd3e 100644 --- a/man/condense_control.Rd +++ b/man/condense_control.Rd @@ -4,13 +4,18 @@ \alias{condense_control} \title{Condense control object into strictly smaller control object} \usage{ -condense_control(x, ref) +condense_control(x, ref, ..., call = rlang::caller_env()) } \arguments{ \item{x}{A control object to be condensed.} \item{ref}{A control object that is used to determine what element should be kept.} + +\item{call}{The execution environment of a currently running function, e.g. +\code{caller_env()}. The function will be mentioned in error messages as the +source of the error. See the call argument of \code{\link[rlang:abort]{rlang::abort()}} for more +information.} } \value{ A control object with the same elements and classes of \code{ref}, with diff --git a/man/contr_one_hot.Rd b/man/contr_one_hot.Rd index f57e589fc..57ff6654c 100644 --- a/man/contr_one_hot.Rd +++ b/man/contr_one_hot.Rd @@ -7,7 +7,8 @@ contr_one_hot(n, contrasts = TRUE, sparse = FALSE) } \arguments{ -\item{n}{A vector of character factor levels or the number of unique levels.} +\item{n}{A vector of character factor levels (of length >=1) or the number +of unique levels (>= 1).} \item{contrasts}{This argument is for backwards compatibility and only the default of \code{TRUE} is supported.} diff --git a/man/convert_helpers.Rd b/man/convert_helpers.Rd index d25f830bd..583e91b46 100644 --- a/man/convert_helpers.Rd +++ b/man/convert_helpers.Rd @@ -14,14 +14,16 @@ na.action = na.omit, indicators = "traditional", composition = "data.frame", - remove_intercept = TRUE + remove_intercept = TRUE, + call = rlang::caller_env() ) .convert_form_to_xy_new( object, new_data, na.action = na.pass, - composition = "data.frame" + composition = "data.frame", + call = rlang::caller_env() ) .convert_xy_to_form_fit( @@ -29,7 +31,8 @@ y, weights = NULL, y_name = "..y", - remove_intercept = TRUE + remove_intercept = TRUE, + call = rlang::caller_env() ) .convert_xy_to_form_new(object, new_data) diff --git a/man/dot-model_param_name_key.Rd b/man/dot-model_param_name_key.Rd index 31372ae8e..cfd1640d8 100644 --- a/man/dot-model_param_name_key.Rd +++ b/man/dot-model_param_name_key.Rd @@ -42,3 +42,4 @@ grid \%>\% dplyr::rename(!!!rn$parsnip_to_engine) \dontshow{\}) # examplesIf} } +\keyword{internal} diff --git a/man/glmnet_helpers.Rd b/man/glmnet_helpers.Rd index e6fcfc3aa..246fb1f20 100644 --- a/man/glmnet_helpers.Rd +++ b/man/glmnet_helpers.Rd @@ -5,9 +5,14 @@ \alias{.check_glmnet_penalty_predict} \title{Helper functions for checking the penalty of glmnet models} \usage{ -.check_glmnet_penalty_fit(x) +.check_glmnet_penalty_fit(x, call = rlang::caller_env()) -.check_glmnet_penalty_predict(penalty = NULL, object, multi = FALSE) +.check_glmnet_penalty_predict( + penalty = NULL, + object, + multi = FALSE, + call = rlang::caller_env() +) } \arguments{ \item{x}{An object of class \code{model_spec}.} diff --git a/tests/testthat/_snaps/condense_control.md b/tests/testthat/_snaps/condense_control.md index 19f908d21..7e9277238 100644 --- a/tests/testthat/_snaps/condense_control.md +++ b/tests/testthat/_snaps/condense_control.md @@ -3,7 +3,16 @@ Code condense_control(control_parsnip(), ctrl) Condition - Error in `condense_control()`: - ! Object of class cannot be coerced to object of class . + Error: + ! a object cannot be coerced to a object. + i The arguments `allow_par` and `anotherone` are missing. + +--- + + Code + control_test(ctrl) + Condition + Error in `control_test()`: + ! a object cannot be coerced to a object. i The arguments `allow_par` and `anotherone` are missing. diff --git a/tests/testthat/_snaps/contr_one_hot.md b/tests/testthat/_snaps/contr_one_hot.md new file mode 100644 index 000000000..f677fd5a6 --- /dev/null +++ b/tests/testthat/_snaps/contr_one_hot.md @@ -0,0 +1,48 @@ +# one-hot encoding contrasts + + Code + contr_one_hot(character(0)) + Condition + Error in `contr_one_hot()`: + ! `n` cannot be empty. + +--- + + Code + contr_one_hot(-1) + Condition + Error in `contr_one_hot()`: + ! `n` must be a whole number larger than or equal to 1, not the number -1. + +--- + + Code + contr_one_hot(list()) + Condition + Error in `contr_one_hot()`: + ! `n` must be a whole number, not an empty list. + +--- + + Code + contr_one_hot(2, contrast = FALSE) + Condition + Warning: + `contrasts = FALSE` not implemented for `contr_one_hot()`. + Output + 1 2 + 1 1 0 + 2 0 1 + +--- + + Code + contr_one_hot(2, sparse = TRUE) + Condition + Warning: + `sparse = TRUE` not implemented for `contr_one_hot()`. + Output + 1 2 + 1 1 0 + 2 0 1 + diff --git a/tests/testthat/_snaps/convert_data.md b/tests/testthat/_snaps/convert_data.md index bc4c5606b..d864f282a 100644 --- a/tests/testthat/_snaps/convert_data.md +++ b/tests/testthat/_snaps/convert_data.md @@ -21,7 +21,7 @@ .convert_form_to_xy_fit(mpg ~ ., data = mtcars, composition = "tibble", indicators = "traditional", remove_intercept = TRUE) Condition - Error in `.convert_form_to_xy_fit()`: + Error: ! `composition` should be either "data.frame", "matrix", or "dgCMatrix". --- @@ -30,7 +30,7 @@ .convert_form_to_xy_fit(mpg ~ ., data = mtcars, weights = letters[1:nrow(mtcars)], indicators = "traditional", remove_intercept = TRUE) Condition - Error in `.convert_form_to_xy_fit()`: + Error: ! `weights` must be a numeric vector. --- @@ -38,7 +38,7 @@ Code .convert_xy_to_form_fit(mtcars$disp, mtcars$mpg, remove_intercept = TRUE) Condition - Error in `.convert_xy_to_form_fit()`: + Error: ! `x` cannot be a vector. --- diff --git a/tests/testthat/_snaps/linear_reg_quantreg.md b/tests/testthat/_snaps/linear_reg_quantreg.md index cba265991..11fbd80e2 100644 --- a/tests/testthat/_snaps/linear_reg_quantreg.md +++ b/tests/testthat/_snaps/linear_reg_quantreg.md @@ -4,6 +4,6 @@ ten_quant_pred <- predict(ten_quant, new_data = sac_test, quantile_levels = (0: 9) / 9) Condition - Error in `predict_quantile()`: + Error in `predict()`: ! When the mode is "quantile regression", `quantile_levels` are specified by `set_mode()`. diff --git a/tests/testthat/_snaps/misc.md b/tests/testthat/_snaps/misc.md index 0bbae37cf..b6b1f918c 100644 --- a/tests/testthat/_snaps/misc.md +++ b/tests/testthat/_snaps/misc.md @@ -1,5 +1,13 @@ # parsnip objects + Code + predict(lm_idea, mtcars) + Condition + Error in `predict()`: + ! You must `fit()` your model specification (`?parsnip::model_spec()`) before you can use `predict()`. + +--- + Code multi_predict(lm_fit, mtcars) Condition diff --git a/tests/testthat/_snaps/sparsevctrs.md b/tests/testthat/_snaps/sparsevctrs.md index 006fc82b8..30c84b1a9 100644 --- a/tests/testthat/_snaps/sparsevctrs.md +++ b/tests/testthat/_snaps/sparsevctrs.md @@ -3,7 +3,7 @@ Code xgb_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data) Condition - Error in `.convert_form_to_xy_fit()`: + Error in `fit()`: ! Sparse data cannot be used with formula interface. Please use `fit_xy()` instead. # sparse tibble can be passed to `fit() - unsupported @@ -19,7 +19,7 @@ Code xgb_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data) Condition - Error in `.convert_form_to_xy_fit()`: + Error in `fit()`: ! Sparse data cannot be used with formula interface. Please use `fit_xy()` instead. # sparse matrix can be passed to `fit() - unsupported @@ -67,7 +67,7 @@ Code xgb_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data) Condition - Error in `.convert_form_to_xy_fit()`: + Error in `fit()`: ! Sparse data cannot be used with formula interface. Please use `fit_xy()` instead. # to_sparse_data_frame() is used correctly diff --git a/tests/testthat/_snaps/translate.md b/tests/testthat/_snaps/translate.md index edb00ca38..b7e2488f3 100644 --- a/tests/testthat/_snaps/translate.md +++ b/tests/testthat/_snaps/translate.md @@ -503,7 +503,7 @@ Code translate_args(basic %>% set_engine("glmnet")) Condition - Error in `.check_glmnet_penalty_fit()`: + Error in `translate()`: x For the glmnet engine, `penalty` must be a single number (or a value of `tune()`). ! There are 0 values for `penalty`. i To try multiple values for total regularization, use the tune package. @@ -554,7 +554,7 @@ Code translate_args(mixture %>% set_engine("glmnet")) Condition - Error in `.check_glmnet_penalty_fit()`: + Error in `translate()`: x For the glmnet engine, `penalty` must be a single number (or a value of `tune()`). ! There are 0 values for `penalty`. i To try multiple values for total regularization, use the tune package. @@ -687,7 +687,7 @@ Code translate_args(basic %>% set_engine("glmnet")) Condition - Error in `.check_glmnet_penalty_fit()`: + Error in `translate()`: x For the glmnet engine, `penalty` must be a single number (or a value of `tune()`). ! There are 0 values for `penalty`. i To try multiple values for total regularization, use the tune package. @@ -826,7 +826,7 @@ Code translate_args(mixture %>% set_engine("glmnet")) Condition - Error in `.check_glmnet_penalty_fit()`: + Error in `translate()`: x For the glmnet engine, `penalty` must be a single number (or a value of `tune()`). ! There are 0 values for `penalty`. i To try multiple values for total regularization, use the tune package. @@ -967,7 +967,7 @@ Code translate_args(mixture_v %>% set_engine("glmnet")) Condition - Error in `.check_glmnet_penalty_fit()`: + Error in `translate()`: x For the glmnet engine, `penalty` must be a single number (or a value of `tune()`). ! There are 0 values for `penalty`. i To try multiple values for total regularization, use the tune package. @@ -1333,7 +1333,7 @@ Code translate_args(basic %>% set_engine("glmnet")) Condition - Error in `.check_glmnet_penalty_fit()`: + Error in `translate()`: x For the glmnet engine, `penalty` must be a single number (or a value of `tune()`). ! There are 0 values for `penalty`. i To try multiple values for total regularization, use the tune package. @@ -1551,7 +1551,7 @@ Code basic_incomplete %>% translate_args() Condition - Error in `.check_glmnet_penalty_fit()`: + Error in `translate()`: x For the glmnet engine, `penalty` must be a single number (or a value of `tune()`). ! There are 0 values for `penalty`. i To try multiple values for total regularization, use the tune package. diff --git a/tests/testthat/_snaps/update.md b/tests/testthat/_snaps/update.md index 394aa2a9f..be0b23487 100644 --- a/tests/testthat/_snaps/update.md +++ b/tests/testthat/_snaps/update.md @@ -196,7 +196,7 @@ Code expr1 %>% update(param_tibb) Condition - Error in `update_main_parameters()`: + Error in `update()`: ! Argument `nlambda` is not a main argument. --- @@ -204,7 +204,7 @@ Code expr1 %>% update(param_list) Condition - Error in `update_main_parameters()`: + Error in `update()`: ! Argument `nlambda` is not a main argument. --- @@ -212,7 +212,7 @@ Code expr1 %>% update(parameters = "wat") Condition - Error in `check_final_param()`: + Error in `update()`: ! The parameter object should be a list or tibble. --- @@ -220,7 +220,7 @@ Code expr1 %>% update(parameters = tibble::tibble(wat = "wat")) Condition - Error in `update_main_parameters()`: + Error in `update()`: ! Argument `wat` is not a main argument. --- diff --git a/tests/testthat/test-condense_control.R b/tests/testthat/test-condense_control.R index 4da0c2d57..345bac5dc 100644 --- a/tests/testthat/test-condense_control.R +++ b/tests/testthat/test-condense_control.R @@ -18,4 +18,11 @@ test_that("condense_control works", { expect_snapshot(error = TRUE, condense_control(control_parsnip(), ctrl) ) + + # Emulate being called from one of the upstream control_* functions + control_test <- function(control = control_parsnip()) { + control <- parsnip::condense_control(control_parsnip(), control) + invisible(control) + } + expect_snapshot(error = TRUE, control_test(ctrl)) }) diff --git a/tests/testthat/test-contr_one_hot.R b/tests/testthat/test-contr_one_hot.R new file mode 100644 index 000000000..3f1d78857 --- /dev/null +++ b/tests/testthat/test-contr_one_hot.R @@ -0,0 +1,19 @@ +test_that('one-hot encoding contrasts', { + contr_mat <- contr_one_hot(12) + expect_equal(colnames(contr_mat), paste(1:12)) + expect_equal(rownames(contr_mat), paste(1:12)) + expect_true(all(apply(contr_mat, 1, sum) == 1)) + expect_true(all(apply(contr_mat, 2, sum) == 1)) + + chr_contr_mat <- contr_one_hot(letters[1:12]) + expect_equal(colnames(chr_contr_mat), letters[1:12]) + expect_equal(rownames(chr_contr_mat), letters[1:12]) + expect_true(all(apply(chr_contr_mat, 1, sum) == 1)) + expect_true(all(apply(chr_contr_mat, 2, sum) == 1)) + + expect_snapshot(contr_one_hot(character(0)), error = TRUE) + expect_snapshot(contr_one_hot(-1), error = TRUE) + expect_snapshot(contr_one_hot(list()), error = TRUE) + expect_snapshot(contr_one_hot(2, contrast = FALSE)) + expect_snapshot(contr_one_hot(2, sparse = TRUE)) +}) diff --git a/tests/testthat/test-misc.R b/tests/testthat/test-misc.R index af18a4a7d..e689bcbcf 100644 --- a/tests/testthat/test-misc.R +++ b/tests/testthat/test-misc.R @@ -5,6 +5,7 @@ test_that('parsnip objects', { lm_idea <- linear_reg() %>% set_engine("lm") expect_false(has_multi_predict(lm_idea)) + expect_snapshot(error = TRUE, predict(lm_idea, mtcars)) lm_fit <- fit(lm_idea, mpg ~ ., data = mtcars) expect_false(has_multi_predict(lm_fit))