From 547abb8c126991e0fc24219616e1e7298e266723 Mon Sep 17 00:00:00 2001 From: david-cortes Date: Mon, 15 Jan 2024 10:16:30 +0100 Subject: [PATCH] [R] Remove unusable 'feature_names' argument and make 'model' first argument in inspection functions (#9939) --- R-package/DESCRIPTION | 2 +- R-package/R/xgb.importance.R | 9 +------ R-package/R/xgb.model.dt.tree.R | 35 +++++++------------------ R-package/R/xgb.plot.multi.trees.R | 4 +-- R-package/R/xgb.plot.tree.R | 14 +++------- R-package/man/xgb.importance.Rd | 6 ++--- R-package/man/xgb.model.dt.tree.Rd | 12 +++------ R-package/man/xgb.plot.multi.trees.Rd | 7 ++--- R-package/man/xgb.plot.tree.Rd | 7 ++--- R-package/tests/testthat/test_helpers.R | 8 ++---- 10 files changed, 30 insertions(+), 74 deletions(-) diff --git a/R-package/DESCRIPTION b/R-package/DESCRIPTION index 7c01d50c6811..bbaf3e75da4e 100644 --- a/R-package/DESCRIPTION +++ b/R-package/DESCRIPTION @@ -65,6 +65,6 @@ Imports: data.table (>= 1.9.6), jsonlite (>= 1.0) Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.3 +RoxygenNote: 7.3.0 Encoding: UTF-8 SystemRequirements: GNU make, C++17 diff --git a/R-package/R/xgb.importance.R b/R-package/R/xgb.importance.R index 44f2eb9b3bf6..547d9677b798 100644 --- a/R-package/R/xgb.importance.R +++ b/R-package/R/xgb.importance.R @@ -113,19 +113,12 @@ #' xgb.importance(model = mbst) #' #' @export -xgb.importance <- function(feature_names = NULL, model = NULL, trees = NULL, +xgb.importance <- function(model = NULL, feature_names = getinfo(model, "feature_name"), trees = NULL, data = NULL, label = NULL, target = NULL) { if (!(is.null(data) && is.null(label) && is.null(target))) warning("xgb.importance: parameters 'data', 'label' and 'target' are deprecated") - if (is.null(feature_names)) { - model_feature_names <- xgb.feature_names(model) - if (NROW(model_feature_names)) { - feature_names <- model_feature_names - } - } - if (!(is.null(feature_names) || is.character(feature_names))) stop("feature_names: Has to be a character 
vector") diff --git a/R-package/R/xgb.model.dt.tree.R b/R-package/R/xgb.model.dt.tree.R index df0e672a92cd..ff416b73e38a 100644 --- a/R-package/R/xgb.model.dt.tree.R +++ b/R-package/R/xgb.model.dt.tree.R @@ -2,11 +2,8 @@ #' #' Parse a boosted tree model text dump into a `data.table` structure. #' -#' @param feature_names Character vector of feature names. If the model already -#' contains feature names, those will be used when \code{feature_names=NULL} (default value). -#' -#' Note that, if the model already contains feature names, it's \bold{not} possible to override them here. -#' @param model Object of class `xgb.Booster`. +#' @param model Object of class `xgb.Booster`. If it contains feature names (they can be set through +#' \link{setinfo}), they will be used in the output from this function. #' @param text Character vector previously generated by the function [xgb.dump()] #' (called with parameter `with_stats = TRUE`). `text` takes precedence over `model`. #' @param trees An integer vector of tree indices that should be used. @@ -58,7 +55,7 @@ #' #' # This bst model already has feature_names stored with it, so those would be used when #' # feature_names is not set: -#' (dt <- xgb.model.dt.tree(model = bst)) +#' dt <- xgb.model.dt.tree(bst) #' #' # How to match feature names of splits that are following a current 'Yes' branch: #' merge( @@ -69,7 +66,7 @@ #' ] #' #' @export -xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL, +xgb.model.dt.tree <- function(model = NULL, text = NULL, trees = NULL, use_int_id = FALSE, ...) { check.deprecation(...) @@ -79,24 +76,15 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL, " (or NULL if 'model' was provided).") } - model_feature_names <- NULL - if (inherits(model, "xgb.Booster")) { - model_feature_names <- xgb.feature_names(model) - if (NROW(model_feature_names) && !is.null(feature_names)) { - stop("'model' contains feature names. 
Cannot override them.") - } - } - if (is.null(feature_names) && !is.null(model) && !is.null(model_feature_names)) - feature_names <- model_feature_names - - if (!(is.null(feature_names) || is.character(feature_names))) { - stop("feature_names: must be a character vector") - } - if (!(is.null(trees) || is.numeric(trees))) { stop("trees: must be a vector of integers.") } + feature_names <- NULL + if (inherits(model, "xgb.Booster")) { + feature_names <- xgb.feature_names(model) + } + from_text <- TRUE if (is.null(text)) { text <- xgb.dump(model = model, with_stats = TRUE) @@ -134,7 +122,7 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL, branch_rx_w_names <- paste0("\\d+:\\[(.+)<(", anynumber_regex, ")\\] yes=(\\d+),no=(\\d+),missing=(\\d+),", "gain=(", anynumber_regex, "),cover=(", anynumber_regex, ")") text_has_feature_names <- FALSE - if (NROW(model_feature_names)) { + if (NROW(feature_names)) { branch_rx <- branch_rx_w_names text_has_feature_names <- TRUE } else { @@ -148,9 +136,6 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL, } } } - if (text_has_feature_names && is.null(model) && !is.null(feature_names)) { - stop("'text' contains feature names. Cannot override them.") - } branch_cols <- c("Feature", "Split", "Yes", "No", "Missing", "Gain", "Cover") td[ isLeaf == FALSE, diff --git a/R-package/R/xgb.plot.multi.trees.R b/R-package/R/xgb.plot.multi.trees.R index 88616cfb7173..e6d678ee7a4f 100644 --- a/R-package/R/xgb.plot.multi.trees.R +++ b/R-package/R/xgb.plot.multi.trees.R @@ -62,13 +62,13 @@ #' } #' #' @export -xgb.plot.multi.trees <- function(model, feature_names = NULL, features_keep = 5, plot_width = NULL, plot_height = NULL, +xgb.plot.multi.trees <- function(model, features_keep = 5, plot_width = NULL, plot_height = NULL, render = TRUE, ...) { if (!requireNamespace("DiagrammeR", quietly = TRUE)) { stop("DiagrammeR is required for xgb.plot.multi.trees") } check.deprecation(...) 
- tree.matrix <- xgb.model.dt.tree(feature_names = feature_names, model = model) + tree.matrix <- xgb.model.dt.tree(model = model) # first number of the path represents the tree, then the following numbers are related to the path to follow # root init diff --git a/R-package/R/xgb.plot.tree.R b/R-package/R/xgb.plot.tree.R index c75a42e84bd7..5ed1e70f695a 100644 --- a/R-package/R/xgb.plot.tree.R +++ b/R-package/R/xgb.plot.tree.R @@ -2,9 +2,8 @@ #' #' Read a tree model text dump and plot the model. #' -#' @param feature_names Character vector used to overwrite the feature names -#' of the model. The default (`NULL`) uses the original feature names. -#' @param model Object of class `xgb.Booster`. +#' @param model Object of class `xgb.Booster`. If it contains feature names (they can be set through +#' \link{setinfo}), they will be used in the output from this function. #' @param trees An integer vector of tree indices that should be used. #' The default (`NULL`) uses all trees. #' Useful, e.g., in multiclass classification to get only @@ -103,7 +102,7 @@ #' } #' #' @export -xgb.plot.tree <- function(feature_names = NULL, model = NULL, trees = NULL, plot_width = NULL, plot_height = NULL, +xgb.plot.tree <- function(model = NULL, trees = NULL, plot_width = NULL, plot_height = NULL, render = TRUE, show_node_id = FALSE, style = c("R", "xgboost"), ...) { check.deprecation(...) if (!inherits(model, "xgb.Booster")) { @@ -120,17 +119,12 @@ xgb.plot.tree <- function(feature_names = NULL, model = NULL, trees = NULL, plot if (NROW(trees) != 1L || !render || show_node_id) { stop("style='xgboost' is only supported for single, rendered tree, without node IDs.") } - if (!is.null(feature_names)) { - stop( - "style='xgboost' cannot override 'feature_names'. Will automatically take them from the model." 
- ) - } txt <- xgb.dump(model, dump_format = "dot") return(DiagrammeR::grViz(txt[[trees + 1]], width = plot_width, height = plot_height)) } - dt <- xgb.model.dt.tree(feature_names = feature_names, model = model, trees = trees) + dt <- xgb.model.dt.tree(model = model, trees = trees) dt[, label := paste0(Feature, "\nCover: ", Cover, ifelse(Feature == "Leaf", "\nValue: ", "\nGain: "), Gain)] if (show_node_id) diff --git a/R-package/man/xgb.importance.Rd b/R-package/man/xgb.importance.Rd index fca1b70c46e3..73b91e8b4b28 100644 --- a/R-package/man/xgb.importance.Rd +++ b/R-package/man/xgb.importance.Rd @@ -5,8 +5,8 @@ \title{Feature importance} \usage{ xgb.importance( - feature_names = NULL, model = NULL, + feature_names = getinfo(model, "feature_name"), trees = NULL, data = NULL, label = NULL, @@ -14,11 +14,11 @@ xgb.importance( ) } \arguments{ +\item{model}{Object of class \code{xgb.Booster}.} + \item{feature_names}{Character vector used to overwrite the feature names of the model. The default is \code{NULL} (use original feature names).} -\item{model}{Object of class \code{xgb.Booster}.} - \item{trees}{An integer vector of tree indices that should be included into the importance calculation (only for the "gbtree" booster). The default (\code{NULL}) parses all trees. diff --git a/R-package/man/xgb.model.dt.tree.Rd b/R-package/man/xgb.model.dt.tree.Rd index e63bd4b10ac2..75f1cd0f4f77 100644 --- a/R-package/man/xgb.model.dt.tree.Rd +++ b/R-package/man/xgb.model.dt.tree.Rd @@ -5,7 +5,6 @@ \title{Parse model text dump} \usage{ xgb.model.dt.tree( - feature_names = NULL, model = NULL, text = NULL, trees = NULL, @@ -14,13 +13,8 @@ xgb.model.dt.tree( ) } \arguments{ -\item{feature_names}{Character vector of feature names. If the model already -contains feature names, those will be used when \code{feature_names=NULL} (default value). - -\if{html}{\out{
<div class="sourceCode">}}\preformatted{ Note that, if the model already contains feature names, it's \\bold\{not\} possible to override them here. -}\if{html}{\out{</div>
}}} - -\item{model}{Object of class \code{xgb.Booster}.} +\item{model}{Object of class \code{xgb.Booster}. If it contains feature names (they can be set through +\link{setinfo}), they will be used in the output from this function.} \item{text}{Character vector previously generated by the function \code{\link[=xgb.dump]{xgb.dump()}} (called with parameter \code{with_stats = TRUE}). \code{text} takes precedence over \code{model}.} @@ -81,7 +75,7 @@ bst <- xgboost( # This bst model already has feature_names stored with it, so those would be used when # feature_names is not set: -(dt <- xgb.model.dt.tree(model = bst)) +dt <- xgb.model.dt.tree(bst) # How to match feature names of splits that are following a current 'Yes' branch: merge( diff --git a/R-package/man/xgb.plot.multi.trees.Rd b/R-package/man/xgb.plot.multi.trees.Rd index d98a3482cde4..7fa75c85d886 100644 --- a/R-package/man/xgb.plot.multi.trees.Rd +++ b/R-package/man/xgb.plot.multi.trees.Rd @@ -6,7 +6,6 @@ \usage{ xgb.plot.multi.trees( model, - feature_names = NULL, features_keep = 5, plot_width = NULL, plot_height = NULL, @@ -15,10 +14,8 @@ xgb.plot.multi.trees( ) } \arguments{ -\item{model}{Object of class \code{xgb.Booster}.} - -\item{feature_names}{Character vector used to overwrite the feature names -of the model. The default (\code{NULL}) uses the original feature names.} +\item{model}{Object of class \code{xgb.Booster}. 
If it contains feature names (they can be set through +\link{setinfo}), they will be used in the output from this function.} \item{features_keep}{Number of features to keep in each position of the multi trees, by default 5.} diff --git a/R-package/man/xgb.plot.tree.Rd b/R-package/man/xgb.plot.tree.Rd index a09bb7183297..69d37301dde6 100644 --- a/R-package/man/xgb.plot.tree.Rd +++ b/R-package/man/xgb.plot.tree.Rd @@ -5,7 +5,6 @@ \title{Plot boosted trees} \usage{ xgb.plot.tree( - feature_names = NULL, model = NULL, trees = NULL, plot_width = NULL, @@ -17,10 +16,8 @@ xgb.plot.tree( ) } \arguments{ -\item{feature_names}{Character vector used to overwrite the feature names -of the model. The default (\code{NULL}) uses the original feature names.} - -\item{model}{Object of class \code{xgb.Booster}.} +\item{model}{Object of class \code{xgb.Booster}. If it contains feature names (they can be set through +\link{setinfo}), they will be used in the output from this function.} \item{trees}{An integer vector of tree indices that should be used. The default (\code{NULL}) uses all trees. 
diff --git a/R-package/tests/testthat/test_helpers.R b/R-package/tests/testthat/test_helpers.R index 372f2520c26f..badac0213292 100644 --- a/R-package/tests/testthat/test_helpers.R +++ b/R-package/tests/testthat/test_helpers.R @@ -282,9 +282,6 @@ test_that("xgb.model.dt.tree works with and without feature names", { expect_equal(dim(dt.tree), c(188, 10)) expect_output(str(dt.tree), 'Feature.*\\"Age\\"') - dt.tree.0 <- xgb.model.dt.tree(model = bst.Tree) - expect_equal(dt.tree, dt.tree.0) - # when model contains no feature names: dt.tree.x <- xgb.model.dt.tree(model = bst.Tree.unnamed) expect_output(str(dt.tree.x), 'Feature.*\\"3\\"') @@ -304,7 +301,7 @@ test_that("xgb.model.dt.tree throws error for gblinear", { test_that("xgb.importance works with and without feature names", { .skip_if_vcd_not_available() - importance.Tree <- xgb.importance(feature_names = feature.names, model = bst.Tree) + importance.Tree <- xgb.importance(feature_names = feature.names, model = bst.Tree.unnamed) if (!flag_32bit) expect_equal(dim(importance.Tree), c(7, 4)) expect_equal(colnames(importance.Tree), c("Feature", "Gain", "Cover", "Frequency")) @@ -330,9 +327,8 @@ test_that("xgb.importance works with and without feature names", { importance <- xgb.importance(feature_names = feature.names, model = bst.Tree, trees = trees) importance_from_dump <- function() { - model_text_dump <- xgb.dump(model = bst.Tree.unnamed, with_stats = TRUE, trees = trees) + model_text_dump <- xgb.dump(model = bst.Tree, with_stats = TRUE, trees = trees) imp <- xgb.model.dt.tree( - feature_names = feature.names, text = model_text_dump, trees = trees )[