diff --git a/R-package/DESCRIPTION b/R-package/DESCRIPTION index 7c01d50c6811..bbaf3e75da4e 100644 --- a/R-package/DESCRIPTION +++ b/R-package/DESCRIPTION @@ -65,6 +65,6 @@ Imports: data.table (>= 1.9.6), jsonlite (>= 1.0) Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.3 +RoxygenNote: 7.3.0 Encoding: UTF-8 SystemRequirements: GNU make, C++17 diff --git a/R-package/R/xgb.importance.R b/R-package/R/xgb.importance.R index 44f2eb9b3bf6..547d9677b798 100644 --- a/R-package/R/xgb.importance.R +++ b/R-package/R/xgb.importance.R @@ -113,19 +113,12 @@ #' xgb.importance(model = mbst) #' #' @export -xgb.importance <- function(feature_names = NULL, model = NULL, trees = NULL, +xgb.importance <- function(model = NULL, feature_names = getinfo(model, "feature_name"), trees = NULL, data = NULL, label = NULL, target = NULL) { if (!(is.null(data) && is.null(label) && is.null(target))) warning("xgb.importance: parameters 'data', 'label' and 'target' are deprecated") - if (is.null(feature_names)) { - model_feature_names <- xgb.feature_names(model) - if (NROW(model_feature_names)) { - feature_names <- model_feature_names - } - } - if (!(is.null(feature_names) || is.character(feature_names))) stop("feature_names: Has to be a character vector") diff --git a/R-package/R/xgb.model.dt.tree.R b/R-package/R/xgb.model.dt.tree.R index df0e672a92cd..ff416b73e38a 100644 --- a/R-package/R/xgb.model.dt.tree.R +++ b/R-package/R/xgb.model.dt.tree.R @@ -2,11 +2,8 @@ #' #' Parse a boosted tree model text dump into a `data.table` structure. #' -#' @param feature_names Character vector of feature names. If the model already -#' contains feature names, those will be used when \code{feature_names=NULL} (default value). -#' -#' Note that, if the model already contains feature names, it's \bold{not} possible to override them here. -#' @param model Object of class `xgb.Booster`. +#' @param model Object of class `xgb.Booster`. If it contains feature names (they can be set through +#' \link{setinfo}), they will be used in the output from this function. #' @param text Character vector previously generated by the function [xgb.dump()] #' (called with parameter `with_stats = TRUE`). `text` takes precedence over `model`. #' @param trees An integer vector of tree indices that should be used. @@ -58,7 +55,7 @@ #' #' # This bst model already has feature_names stored with it, so those would be used when #' # feature_names is not set: -#' (dt <- xgb.model.dt.tree(model = bst)) +#' dt <- xgb.model.dt.tree(bst) #' #' # How to match feature names of splits that are following a current 'Yes' branch: #' merge( @@ -69,7 +66,7 @@ #' ] #' #' @export -xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL, +xgb.model.dt.tree <- function(model = NULL, text = NULL, trees = NULL, use_int_id = FALSE, ...) { check.deprecation(...) @@ -79,24 +76,15 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL, " (or NULL if 'model' was provided).") } - model_feature_names <- NULL - if (inherits(model, "xgb.Booster")) { - model_feature_names <- xgb.feature_names(model) - if (NROW(model_feature_names) && !is.null(feature_names)) { - stop("'model' contains feature names. Cannot override them.") - } - } - if (is.null(feature_names) && !is.null(model) && !is.null(model_feature_names)) - feature_names <- model_feature_names - - if (!(is.null(feature_names) || is.character(feature_names))) { - stop("feature_names: must be a character vector") - } - if (!(is.null(trees) || is.numeric(trees))) { stop("trees: must be a vector of integers.") } + feature_names <- NULL + if (inherits(model, "xgb.Booster")) { + feature_names <- xgb.feature_names(model) + } + from_text <- TRUE if (is.null(text)) { text <- xgb.dump(model = model, with_stats = TRUE) @@ -134,7 +122,7 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL, branch_rx_w_names <- paste0("\\d+:\\[(.+)<(", anynumber_regex, ")\\] yes=(\\d+),no=(\\d+),missing=(\\d+),", "gain=(", anynumber_regex, "),cover=(", anynumber_regex, ")") text_has_feature_names <- FALSE - if (NROW(model_feature_names)) { + if (NROW(feature_names)) { branch_rx <- branch_rx_w_names text_has_feature_names <- TRUE } else { @@ -148,9 +136,6 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL, } } } - if (text_has_feature_names && is.null(model) && !is.null(feature_names)) { - stop("'text' contains feature names. Cannot override them.") - } branch_cols <- c("Feature", "Split", "Yes", "No", "Missing", "Gain", "Cover") td[ isLeaf == FALSE, diff --git a/R-package/R/xgb.plot.multi.trees.R b/R-package/R/xgb.plot.multi.trees.R index 88616cfb7173..e6d678ee7a4f 100644 --- a/R-package/R/xgb.plot.multi.trees.R +++ b/R-package/R/xgb.plot.multi.trees.R @@ -62,13 +62,13 @@ #' } #' #' @export -xgb.plot.multi.trees <- function(model, feature_names = NULL, features_keep = 5, plot_width = NULL, plot_height = NULL, +xgb.plot.multi.trees <- function(model, features_keep = 5, plot_width = NULL, plot_height = NULL, render = TRUE, ...) { if (!requireNamespace("DiagrammeR", quietly = TRUE)) { stop("DiagrammeR is required for xgb.plot.multi.trees") } check.deprecation(...) - tree.matrix <- xgb.model.dt.tree(feature_names = feature_names, model = model) + tree.matrix <- xgb.model.dt.tree(model = model) # first number of the path represents the tree, then the following numbers are related to the path to follow # root init diff --git a/R-package/R/xgb.plot.tree.R b/R-package/R/xgb.plot.tree.R index c75a42e84bd7..5ed1e70f695a 100644 --- a/R-package/R/xgb.plot.tree.R +++ b/R-package/R/xgb.plot.tree.R @@ -2,9 +2,8 @@ #' #' Read a tree model text dump and plot the model. #' -#' @param feature_names Character vector used to overwrite the feature names -#' of the model. The default (`NULL`) uses the original feature names. -#' @param model Object of class `xgb.Booster`. +#' @param model Object of class `xgb.Booster`. If it contains feature names (they can be set through +#' \link{setinfo}), they will be used in the output from this function. #' @param trees An integer vector of tree indices that should be used. #' The default (`NULL`) uses all trees. #' Useful, e.g., in multiclass classification to get only @@ -103,7 +102,7 @@ #' } #' #' @export -xgb.plot.tree <- function(feature_names = NULL, model = NULL, trees = NULL, plot_width = NULL, plot_height = NULL, +xgb.plot.tree <- function(model = NULL, trees = NULL, plot_width = NULL, plot_height = NULL, render = TRUE, show_node_id = FALSE, style = c("R", "xgboost"), ...) { check.deprecation(...) if (!inherits(model, "xgb.Booster")) { @@ -120,17 +119,12 @@ xgb.plot.tree <- function(feature_names = NULL, model = NULL, trees = NULL, plot if (NROW(trees) != 1L || !render || show_node_id) { stop("style='xgboost' is only supported for single, rendered tree, without node IDs.") } - if (!is.null(feature_names)) { - stop( - "style='xgboost' cannot override 'feature_names'. Will automatically take them from the model." - ) - } txt <- xgb.dump(model, dump_format = "dot") return(DiagrammeR::grViz(txt[[trees + 1]], width = plot_width, height = plot_height)) } - dt <- xgb.model.dt.tree(feature_names = feature_names, model = model, trees = trees) + dt <- xgb.model.dt.tree(model = model, trees = trees) dt[, label := paste0(Feature, "\nCover: ", Cover, ifelse(Feature == "Leaf", "\nValue: ", "\nGain: "), Gain)] if (show_node_id) diff --git a/R-package/man/xgb.importance.Rd b/R-package/man/xgb.importance.Rd index fca1b70c46e3..73b91e8b4b28 100644 --- a/R-package/man/xgb.importance.Rd +++ b/R-package/man/xgb.importance.Rd @@ -5,8 +5,8 @@ \title{Feature importance} \usage{ xgb.importance( - feature_names = NULL, model = NULL, + feature_names = getinfo(model, "feature_name"), trees = NULL, data = NULL, label = NULL, @@ -14,11 +14,11 @@ xgb.importance( ) } \arguments{ +\item{model}{Object of class \code{xgb.Booster}.} + \item{feature_names}{Character vector used to overwrite the feature names of the model. The default is \code{NULL} (use original feature names).} -\item{model}{Object of class \code{xgb.Booster}.} - \item{trees}{An integer vector of tree indices that should be included into the importance calculation (only for the "gbtree" booster). The default (\code{NULL}) parses all trees. diff --git a/R-package/man/xgb.model.dt.tree.Rd b/R-package/man/xgb.model.dt.tree.Rd index e63bd4b10ac2..75f1cd0f4f77 100644 --- a/R-package/man/xgb.model.dt.tree.Rd +++ b/R-package/man/xgb.model.dt.tree.Rd @@ -5,7 +5,6 @@ \title{Parse model text dump} \usage{ xgb.model.dt.tree( - feature_names = NULL, model = NULL, text = NULL, trees = NULL, @@ -14,13 +13,8 @@ xgb.model.dt.tree( ) } \arguments{ -\item{feature_names}{Character vector of feature names. If the model already -contains feature names, those will be used when \code{feature_names=NULL} (default value). - -\if{html}{\out{