Skip to content

Commit

Permalink
[R] Remove unusable 'feature_names' argument and make 'model' first argument in inspection functions (#9939)
Browse files Browse the repository at this point in the history
  • Loading branch information
david-cortes authored Jan 15, 2024
1 parent 1168a68 commit 547abb8
Show file tree
Hide file tree
Showing 10 changed files with 30 additions and 74 deletions.
2 changes: 1 addition & 1 deletion R-package/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,6 @@ Imports:
data.table (>= 1.9.6),
jsonlite (>= 1.0)
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.0
Encoding: UTF-8
SystemRequirements: GNU make, C++17
9 changes: 1 addition & 8 deletions R-package/R/xgb.importance.R
Original file line number Diff line number Diff line change
Expand Up @@ -113,19 +113,12 @@
#' xgb.importance(model = mbst)
#'
#' @export
xgb.importance <- function(feature_names = NULL, model = NULL, trees = NULL,
xgb.importance <- function(model = NULL, feature_names = getinfo(model, "feature_name"), trees = NULL,
data = NULL, label = NULL, target = NULL) {

if (!(is.null(data) && is.null(label) && is.null(target)))
warning("xgb.importance: parameters 'data', 'label' and 'target' are deprecated")

if (is.null(feature_names)) {
model_feature_names <- xgb.feature_names(model)
if (NROW(model_feature_names)) {
feature_names <- model_feature_names
}
}

if (!(is.null(feature_names) || is.character(feature_names)))
stop("feature_names: Has to be a character vector")

Expand Down
35 changes: 10 additions & 25 deletions R-package/R/xgb.model.dt.tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@
#'
#' Parse a boosted tree model text dump into a `data.table` structure.
#'
#' @param feature_names Character vector of feature names. If the model already
#' contains feature names, those will be used when \code{feature_names=NULL} (default value).
#'
#' Note that, if the model already contains feature names, it's \bold{not} possible to override them here.
#' @param model Object of class `xgb.Booster`.
#' @param model Object of class `xgb.Booster`. If it contains feature names (they can be set through
#' \link{setinfo}), they will be used in the output from this function.
#' @param text Character vector previously generated by the function [xgb.dump()]
#' (called with parameter `with_stats = TRUE`). `text` takes precedence over `model`.
#' @param trees An integer vector of tree indices that should be used.
Expand Down Expand Up @@ -58,7 +55,7 @@
#'
#' # This bst model already has feature names stored with it, so they are used
#' # automatically in the output:
#' (dt <- xgb.model.dt.tree(model = bst))
#' dt <- xgb.model.dt.tree(bst)
#'
#' # How to match feature names of splits that are following a current 'Yes' branch:
#' merge(
Expand All @@ -69,7 +66,7 @@
#' ]
#'
#' @export
xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
xgb.model.dt.tree <- function(model = NULL, text = NULL,
trees = NULL, use_int_id = FALSE, ...) {
check.deprecation(...)

Expand All @@ -79,24 +76,15 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
" (or NULL if 'model' was provided).")
}

model_feature_names <- NULL
if (inherits(model, "xgb.Booster")) {
model_feature_names <- xgb.feature_names(model)
if (NROW(model_feature_names) && !is.null(feature_names)) {
stop("'model' contains feature names. Cannot override them.")
}
}
if (is.null(feature_names) && !is.null(model) && !is.null(model_feature_names))
feature_names <- model_feature_names

if (!(is.null(feature_names) || is.character(feature_names))) {
stop("feature_names: must be a character vector")
}

if (!(is.null(trees) || is.numeric(trees))) {
stop("trees: must be a vector of integers.")
}

feature_names <- NULL
if (inherits(model, "xgb.Booster")) {
feature_names <- xgb.feature_names(model)
}

from_text <- TRUE
if (is.null(text)) {
text <- xgb.dump(model = model, with_stats = TRUE)
Expand Down Expand Up @@ -134,7 +122,7 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
branch_rx_w_names <- paste0("\\d+:\\[(.+)<(", anynumber_regex, ")\\] yes=(\\d+),no=(\\d+),missing=(\\d+),",
"gain=(", anynumber_regex, "),cover=(", anynumber_regex, ")")
text_has_feature_names <- FALSE
if (NROW(model_feature_names)) {
if (NROW(feature_names)) {
branch_rx <- branch_rx_w_names
text_has_feature_names <- TRUE
} else {
Expand All @@ -148,9 +136,6 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
}
}
}
if (text_has_feature_names && is.null(model) && !is.null(feature_names)) {
stop("'text' contains feature names. Cannot override them.")
}
branch_cols <- c("Feature", "Split", "Yes", "No", "Missing", "Gain", "Cover")
td[
isLeaf == FALSE,
Expand Down
4 changes: 2 additions & 2 deletions R-package/R/xgb.plot.multi.trees.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,13 @@
#' }
#'
#' @export
xgb.plot.multi.trees <- function(model, feature_names = NULL, features_keep = 5, plot_width = NULL, plot_height = NULL,
xgb.plot.multi.trees <- function(model, features_keep = 5, plot_width = NULL, plot_height = NULL,
render = TRUE, ...) {
if (!requireNamespace("DiagrammeR", quietly = TRUE)) {
stop("DiagrammeR is required for xgb.plot.multi.trees")
}
check.deprecation(...)
tree.matrix <- xgb.model.dt.tree(feature_names = feature_names, model = model)
tree.matrix <- xgb.model.dt.tree(model = model)

# first number of the path represents the tree, then the following numbers are related to the path to follow
# root init
Expand Down
14 changes: 4 additions & 10 deletions R-package/R/xgb.plot.tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
#'
#' Read a tree model text dump and plot the model.
#'
#' @param feature_names Character vector used to overwrite the feature names
#' of the model. The default (`NULL`) uses the original feature names.
#' @param model Object of class `xgb.Booster`.
#' @param model Object of class `xgb.Booster`. If it contains feature names (they can be set through
#' \link{setinfo}), they will be used in the output from this function.
#' @param trees An integer vector of tree indices that should be used.
#' The default (`NULL`) uses all trees.
#' Useful, e.g., in multiclass classification to get only
Expand Down Expand Up @@ -103,7 +102,7 @@
#' }
#'
#' @export
xgb.plot.tree <- function(feature_names = NULL, model = NULL, trees = NULL, plot_width = NULL, plot_height = NULL,
xgb.plot.tree <- function(model = NULL, trees = NULL, plot_width = NULL, plot_height = NULL,
render = TRUE, show_node_id = FALSE, style = c("R", "xgboost"), ...) {
check.deprecation(...)
if (!inherits(model, "xgb.Booster")) {
Expand All @@ -120,17 +119,12 @@ xgb.plot.tree <- function(feature_names = NULL, model = NULL, trees = NULL, plot
if (NROW(trees) != 1L || !render || show_node_id) {
stop("style='xgboost' is only supported for single, rendered tree, without node IDs.")
}
if (!is.null(feature_names)) {
stop(
"style='xgboost' cannot override 'feature_names'. Will automatically take them from the model."
)
}

txt <- xgb.dump(model, dump_format = "dot")
return(DiagrammeR::grViz(txt[[trees + 1]], width = plot_width, height = plot_height))
}

dt <- xgb.model.dt.tree(feature_names = feature_names, model = model, trees = trees)
dt <- xgb.model.dt.tree(model = model, trees = trees)

dt[, label := paste0(Feature, "\nCover: ", Cover, ifelse(Feature == "Leaf", "\nValue: ", "\nGain: "), Gain)]
if (show_node_id)
Expand Down
6 changes: 3 additions & 3 deletions R-package/man/xgb.importance.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 3 additions & 9 deletions R-package/man/xgb.model.dt.tree.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 2 additions & 5 deletions R-package/man/xgb.plot.multi.trees.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 2 additions & 5 deletions R-package/man/xgb.plot.tree.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 2 additions & 6 deletions R-package/tests/testthat/test_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,6 @@ test_that("xgb.model.dt.tree works with and without feature names", {
expect_equal(dim(dt.tree), c(188, 10))
expect_output(str(dt.tree), 'Feature.*\\"Age\\"')

dt.tree.0 <- xgb.model.dt.tree(model = bst.Tree)
expect_equal(dt.tree, dt.tree.0)

# when model contains no feature names:
dt.tree.x <- xgb.model.dt.tree(model = bst.Tree.unnamed)
expect_output(str(dt.tree.x), 'Feature.*\\"3\\"')
Expand All @@ -304,7 +301,7 @@ test_that("xgb.model.dt.tree throws error for gblinear", {

test_that("xgb.importance works with and without feature names", {
.skip_if_vcd_not_available()
importance.Tree <- xgb.importance(feature_names = feature.names, model = bst.Tree)
importance.Tree <- xgb.importance(feature_names = feature.names, model = bst.Tree.unnamed)
if (!flag_32bit)
expect_equal(dim(importance.Tree), c(7, 4))
expect_equal(colnames(importance.Tree), c("Feature", "Gain", "Cover", "Frequency"))
Expand All @@ -330,9 +327,8 @@ test_that("xgb.importance works with and without feature names", {
importance <- xgb.importance(feature_names = feature.names, model = bst.Tree, trees = trees)

importance_from_dump <- function() {
model_text_dump <- xgb.dump(model = bst.Tree.unnamed, with_stats = TRUE, trees = trees)
model_text_dump <- xgb.dump(model = bst.Tree, with_stats = TRUE, trees = trees)
imp <- xgb.model.dt.tree(
feature_names = feature.names,
text = model_text_dump,
trees = trees
)[
Expand Down

0 comments on commit 547abb8

Please sign in to comment.