Skip to content

Commit

Permalink
Merge branch 'master' into fix_dt_warning
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Sep 28, 2024
2 parents cca4234 + 521324b commit 02e9ca8
Show file tree
Hide file tree
Showing 277 changed files with 10,719 additions and 9,997 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/jvm_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
with:
submodules: 'true'

- uses: actions/setup-java@99b8673ff64fbf99d8d325f52d9a5bdedb8483e9 # v4.2.1
- uses: actions/setup-java@6a0805fcefea3d4657a47ac4c165951e33482018 # v4.2.2
with:
distribution: 'temurin'
java-version: '8'
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ jobs:
- uses: actions/checkout@a5ac7e51b41094c92402da3b24376905380afc29 # v4.1.6
with:
submodules: 'true'
- uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # v5.1.1
- uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
with:
python-version: "3.10"
architecture: 'x64'
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/python_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ jobs:
submodules: 'true'

- name: Set up Python 3.10
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # v5.1.1
uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
with:
python-version: "3.10"

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/r_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ jobs:
key: ${{ runner.os }}-r-${{ matrix.config.r }}-7-${{ hashFiles('R-package/DESCRIPTION') }}
restore-keys: ${{ runner.os }}-r-${{ matrix.config.r }}-7-${{ hashFiles('R-package/DESCRIPTION') }}

- uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # v5.1.1
- uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
with:
python-version: "3.10"
architecture: 'x64'
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/scorecards.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ jobs:
# Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF
# format to the repository Actions tab.
- name: "Upload artifact"
uses: actions/upload-artifact@0b2256b8c012f0828dc542b3febcab082c67f72b # v4.3.4
uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
with:
name: SARIF file
path: results.sarif
Expand Down
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,13 @@
*vali
*sdf
Release
*exe*
*exe
*exp
ipch
*.filters
*.user
*log
rmm_log.txt
Debug
*suo
.Rhistory
Expand Down
1 change: 0 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,6 @@ if(USE_CUDA)
if(DEFINED GPU_COMPUTE_VER)
compute_cmake_cuda_archs("${GPU_COMPUTE_VER}")
endif()
add_subdirectory(${PROJECT_SOURCE_DIR}/gputreeshap)

find_package(CUDAToolkit REQUIRED)
find_package(CCCL CONFIG)
Expand Down
19 changes: 15 additions & 4 deletions R-package/R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,8 @@ NULL
#' its own serializers with better compatibility guarantees, which allow loading
#' said models in other language bindings of XGBoost.
#'
#' Note that an `xgb.Booster` object, outside of its core components, might also keep:
#' Note that an `xgb.Booster` object (**as produced by [xgb.train()]**, see rest of the doc
#' for objects produced by [xgboost()]), outside of its core components, might also keep:
#' - Additional model configuration (accessible through [xgb.config()]), which includes
#' model fitting parameters like `max_depth` and runtime parameters like `nthread`.
#' These are not necessarily useful for prediction/importance/plotting.
Expand All @@ -450,6 +451,16 @@ NULL
#' not used for prediction / importance / plotting / etc.
#' These R attributes are only preserved when using R's serializers.
#'
#' In addition to the regular `xgb.Booster` objects producted by [xgb.train()], the
#' function [xgboost()] produces a different subclass `xgboost`, which keeps other
#' additional metadata as R attributes such as class names in classification problems,
#' and which has a dedicated `predict` method that uses different defaults. XGBoost's
#' own serializers can work with this `xgboost` class, but as they do not keep R
#' attributes, the resulting object, when deserialized, is downcasted to the regular
#' `xgb.Booster` class (i.e. it loses the metadata, and the resulting object will use
#' `predict.xgb.Booster` instead of `predict.xgboost`) - for these `xgboost` objects,
#' `saveRDS` might thus be a better option if the extra functionalities are needed.
#'
#' Note that XGBoost models in R starting from version `2.1.0` and onwards, and
#' XGBoost models before version `2.1.0`; have a very different R object structure and
#' are incompatible with each other. Hence, models that were saved with R serializers
Expand All @@ -474,9 +485,9 @@ NULL
#' as part of another R object.
#'
#' Use [saveRDS()] if you require the R-specific attributes that a booster might have, such
#' as evaluation logs, but note that future compatibility of such objects is outside XGBoost's
#' control as it relies on R's serialization format (see e.g. the details section in
#' [serialize] and [save()] from base R).
#' as evaluation logs or the model class `xgboost` instead of `xgb.Booster`, but note that
#' future compatibility of such objects is outside XGBoost's control as it relies on R's
#' serialization format (see e.g. the details section in [serialize] and [save()] from base R).
#'
#' For more details and explanation about model persistence and archival, consult the page
#' \url{https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html}.
Expand Down
28 changes: 27 additions & 1 deletion R-package/R/xgb.ggplot.R
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,27 @@ xgb.ggplot.deepness <- function(model = NULL, which = c("2x1", "max.depth", "med
#' @export
xgb.ggplot.shap.summary <- function(data, shap_contrib = NULL, features = NULL, top_n = 10, model = NULL,
trees = NULL, target_class = NULL, approxcontrib = FALSE, subsample = NULL) {
if (inherits(data, "xgb.DMatrix")) {
stop(
"'xgb.ggplot.shap.summary' is not compatible with 'xgb.DMatrix' objects. Try passing a matrix or data.frame."
)
}
cols_categ <- NULL
if (!is.null(model)) {
ftypes <- getinfo(model, "feature_type")
if (NROW(ftypes)) {
if (length(ftypes) != ncol(data)) {
stop(sprintf("'data' has incorrect number of columns (expected: %d, got: %d).", length(ftypes), ncol(data)))
}
cols_categ <- colnames(data)[ftypes == "c"]
}
} else if (inherits(data, "data.frame")) {
cols_categ <- names(data)[sapply(data, function(x) is.factor(x) || is.character(x))]
}
if (NROW(cols_categ)) {
warning("Categorical features are ignored in 'xgb.ggplot.shap.summary'.")
}

data_list <- xgb.shap.data(
data = data,
shap_contrib = shap_contrib,
Expand All @@ -114,6 +135,10 @@ xgb.ggplot.shap.summary <- function(data, shap_contrib = NULL, features = NULL,
subsample = subsample,
max_observations = 10000 # 10,000 samples per feature.
)
if (NROW(cols_categ)) {
data_list <- lapply(data_list, function(x) x[, !(colnames(x) %in% cols_categ), drop = FALSE])
}

p_data <- prepare.ggplot.shap.data(data_list, normalize = TRUE)
# Reverse factor levels so that the first level is at the top of the plot
p_data[, "feature" := factor(feature, rev(levels(feature)))]
Expand All @@ -134,7 +159,8 @@ xgb.ggplot.shap.summary <- function(data, shap_contrib = NULL, features = NULL,
#' @param data_list The result of `xgb.shap.data()`.
#' @param normalize Whether to standardize feature values to mean 0 and
#' standard deviation 1. This is useful for comparing multiple features on the same
#' plot. Default is `FALSE`.
#' plot. Default is `FALSE`. Note that it cannot be used when the data contains
#' categorical features.
#' @return A `data.table` containing the observation ID, the feature name, the
#' feature value (normalized if specified), and the SHAP contribution value.
#' @noRd
Expand Down
2 changes: 1 addition & 1 deletion R-package/R/xgb.model.dt.tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ xgb.model.dt.tree <- function(model = NULL, text = NULL,
from_text <- FALSE
}

if (length(text) < 2 || !any(grepl('leaf=(\\d+)', text))) {
if (length(text) < 2 || !any(grepl('leaf=(-?\\d+)', text))) {
stop("Non-tree model detected! This function can only be used with tree models.")
}

Expand Down
18 changes: 14 additions & 4 deletions R-package/R/xgb.plot.shap.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#'
#' Visualizes SHAP values against feature values to gain an impression of feature effects.
#'
#' @param data The data to explain as a `matrix` or `dgCMatrix`.
#' @param data The data to explain as a `matrix`, `dgCMatrix`, or `data.frame`.
#' @param shap_contrib Matrix of SHAP contributions of `data`.
#' The default (`NULL`) computes it from `model` and `data`.
#' @param features Vector of column indices or feature names to plot. When `NULL`
Expand Down Expand Up @@ -285,8 +285,11 @@ xgb.plot.shap.summary <- function(data, shap_contrib = NULL, features = NULL, to
xgb.shap.data <- function(data, shap_contrib = NULL, features = NULL, top_n = 1, model = NULL,
trees = NULL, target_class = NULL, approxcontrib = FALSE,
subsample = NULL, max_observations = 100000) {
if (!is.matrix(data) && !inherits(data, "dgCMatrix"))
stop("data: must be either matrix or dgCMatrix")
if (!inherits(data, c("matrix", "dsparseMatrix", "data.frame")))
stop("data: must be matrix, sparse matrix, or data.frame.")
if (inherits(data, "data.frame") && length(class(data)) > 1L) {
data <- as.data.frame(data)
}

if (is.null(shap_contrib) && (is.null(model) || !inherits(model, "xgb.Booster")))
stop("when shap_contrib is not provided, one must provide an xgb.Booster model")
Expand All @@ -311,7 +314,14 @@ xgb.shap.data <- function(data, shap_contrib = NULL, features = NULL, top_n = 1,
stop("if model has no feature_names, columns in `data` must match features in model")

if (!is.null(subsample)) {
idx <- sample(x = seq_len(nrow(data)), size = as.integer(subsample * nrow(data)), replace = FALSE)
if (subsample <= 0 || subsample >= 1) {
stop("'subsample' must be a number between zero and one (non-inclusive).")
}
sample_size <- as.integer(subsample * nrow(data))
if (sample_size < 2) {
stop("Sampling fraction involves less than 2 rows.")
}
idx <- sample(x = seq_len(nrow(data)), size = sample_size, replace = FALSE)
} else {
idx <- seq_len(min(nrow(data), max_observations))
}
Expand Down
19 changes: 15 additions & 4 deletions R-package/man/a-compatibility-note-for-saveRDS-save.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion R-package/man/xgb.plot.shap.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion R-package/man/xgb.plot.shap.summary.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

58 changes: 58 additions & 0 deletions R-package/tests/testthat/test_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,26 @@ test_that("xgb.shap.data works with subsampling", {
expect_equal(NROW(data_list$data), NROW(data_list$shap_contrib))
})

test_that("xgb.shap.data works with data frames", {
data(mtcars)
df <- mtcars
df$cyl <- factor(df$cyl)
x <- df[, -1]
y <- df$mpg
dm <- xgb.DMatrix(x, label = y, nthread = 1L)
model <- xgb.train(
data = dm,
params = list(
max_depth = 2,
nthread = 1
),
nrounds = 2
)
data_list <- xgb.shap.data(data = df[, -1], model = model, top_n = 2, subsample = 0.8)
expect_equal(NROW(data_list$data), as.integer(0.8 * nrow(df)))
expect_equal(NROW(data_list$data), NROW(data_list$shap_contrib))
})

test_that("prepare.ggplot.shap.data works", {
.skip_if_vcd_not_available()
data_list <- xgb.shap.data(data = sparse_matrix, model = bst.Tree, top_n = 2)
Expand All @@ -472,6 +492,44 @@ test_that("xgb.plot.shap.summary works", {
expect_silent(xgb.ggplot.shap.summary(data = sparse_matrix, model = bst.Tree, top_n = 2))
})

test_that("xgb.plot.shap.summary ignores categorical features", {
.skip_if_vcd_not_available()
data(mtcars)
df <- mtcars
df$cyl <- factor(df$cyl)
levels(df$cyl) <- c("a", "b", "c")
x <- df[, -1]
y <- df$mpg
dm <- xgb.DMatrix(x, label = y, nthread = 1L)
model <- xgb.train(
data = dm,
params = list(
max_depth = 2,
nthread = 1
),
nrounds = 2
)
expect_warning({
xgb.ggplot.shap.summary(data = x, model = model, top_n = 2)
})

x_num <- mtcars[, -1]
x_num$gear <- as.numeric(x_num$gear) - 1
x_num <- as.matrix(x_num)
dm <- xgb.DMatrix(x_num, label = y, feature_types = c(rep("q", 8), "c", "q"), nthread = 1L)
model <- xgb.train(
data = dm,
params = list(
max_depth = 2,
nthread = 1
),
nrounds = 2
)
expect_warning({
xgb.ggplot.shap.summary(data = x_num, model = model, top_n = 2)
})
})

test_that("check.deprecation works", {
ttt <- function(a = NNULL, DUMMY = NULL, ...) {
check.deprecation(...)
Expand Down
Loading

0 comments on commit 02e9ca8

Please sign in to comment.