Skip to content

Commit

Permalink
[R] Allow passing data.frame to SHAP (#10744)
Browse files Browse the repository at this point in the history
  • Loading branch information
david-cortes authored Sep 2, 2024
1 parent ec8cfb3 commit f52f11e
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 7 deletions.
28 changes: 27 additions & 1 deletion R-package/R/xgb.ggplot.R
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,27 @@ xgb.ggplot.deepness <- function(model = NULL, which = c("2x1", "max.depth", "med
#' @export
xgb.ggplot.shap.summary <- function(data, shap_contrib = NULL, features = NULL, top_n = 10, model = NULL,
trees = NULL, target_class = NULL, approxcontrib = FALSE, subsample = NULL) {
if (inherits(data, "xgb.DMatrix")) {
stop(
"'xgb.ggplot.shap.summary' is not compatible with 'xgb.DMatrix' objects. Try passing a matrix or data.frame."
)
}
cols_categ <- NULL
if (!is.null(model)) {
ftypes <- getinfo(model, "feature_type")
if (NROW(ftypes)) {
if (length(ftypes) != ncol(data)) {
stop(sprintf("'data' has incorrect number of columns (expected: %d, got: %d).", length(ftypes), ncol(data)))
}
cols_categ <- colnames(data)[ftypes == "c"]
}
} else if (inherits(data, "data.frame")) {
cols_categ <- names(data)[sapply(data, function(x) is.factor(x) || is.character(x))]
}
if (NROW(cols_categ)) {
warning("Categorical features are ignored in 'xgb.ggplot.shap.summary'.")
}

data_list <- xgb.shap.data(
data = data,
shap_contrib = shap_contrib,
Expand All @@ -114,6 +135,10 @@ xgb.ggplot.shap.summary <- function(data, shap_contrib = NULL, features = NULL,
subsample = subsample,
max_observations = 10000 # 10,000 samples per feature.
)
if (NROW(cols_categ)) {
data_list <- lapply(data_list, function(x) x[, !(colnames(x) %in% cols_categ), drop = FALSE])
}

p_data <- prepare.ggplot.shap.data(data_list, normalize = TRUE)
# Reverse factor levels so that the first level is at the top of the plot
p_data[, "feature" := factor(feature, rev(levels(feature)))]
Expand All @@ -134,7 +159,8 @@ xgb.ggplot.shap.summary <- function(data, shap_contrib = NULL, features = NULL,
#' @param data_list The result of `xgb.shap.data()`.
#' @param normalize Whether to standardize feature values to mean 0 and
#' standard deviation 1. This is useful for comparing multiple features on the same
#' plot. Default is `FALSE`.
#' plot. Default is `FALSE`. Note that it cannot be used when the data contains
#' categorical features.
#' @return A `data.table` containing the observation ID, the feature name, the
#' feature value (normalized if specified), and the SHAP contribution value.
#' @noRd
Expand Down
18 changes: 14 additions & 4 deletions R-package/R/xgb.plot.shap.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#'
#' Visualizes SHAP values against feature values to gain an impression of feature effects.
#'
#' @param data The data to explain as a `matrix` or `dgCMatrix`.
#' @param data The data to explain as a `matrix`, `dgCMatrix`, or `data.frame`.
#' @param shap_contrib Matrix of SHAP contributions of `data`.
#' The default (`NULL`) computes it from `model` and `data`.
#' @param features Vector of column indices or feature names to plot. When `NULL`
Expand Down Expand Up @@ -285,8 +285,11 @@ xgb.plot.shap.summary <- function(data, shap_contrib = NULL, features = NULL, to
xgb.shap.data <- function(data, shap_contrib = NULL, features = NULL, top_n = 1, model = NULL,
trees = NULL, target_class = NULL, approxcontrib = FALSE,
subsample = NULL, max_observations = 100000) {
if (!is.matrix(data) && !inherits(data, "dgCMatrix"))
stop("data: must be either matrix or dgCMatrix")
if (!inherits(data, c("matrix", "dsparseMatrix", "data.frame")))
stop("data: must be matrix, sparse matrix, or data.frame.")
if (inherits(data, "data.frame") && length(class(data)) > 1L) {
data <- as.data.frame(data)
}

if (is.null(shap_contrib) && (is.null(model) || !inherits(model, "xgb.Booster")))
stop("when shap_contrib is not provided, one must provide an xgb.Booster model")
Expand All @@ -311,7 +314,14 @@ xgb.shap.data <- function(data, shap_contrib = NULL, features = NULL, top_n = 1,
stop("if model has no feature_names, columns in `data` must match features in model")

if (!is.null(subsample)) {
idx <- sample(x = seq_len(nrow(data)), size = as.integer(subsample * nrow(data)), replace = FALSE)
if (subsample <= 0 || subsample >= 1) {
stop("'subsample' must be a number between zero and one (non-inclusive).")
}
sample_size <- as.integer(subsample * nrow(data))
if (sample_size < 2) {
stop("Sampling fraction involves less than 2 rows.")
}
idx <- sample(x = seq_len(nrow(data)), size = sample_size, replace = FALSE)
} else {
idx <- seq_len(min(nrow(data), max_observations))
}
Expand Down
2 changes: 1 addition & 1 deletion R-package/man/xgb.plot.shap.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion R-package/man/xgb.plot.shap.summary.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

58 changes: 58 additions & 0 deletions R-package/tests/testthat/test_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,26 @@ test_that("xgb.shap.data works with subsampling", {
expect_equal(NROW(data_list$data), NROW(data_list$shap_contrib))
})

test_that("xgb.shap.data works with data frames", {
data(mtcars)
df <- mtcars
df$cyl <- factor(df$cyl)
x <- df[, -1]
y <- df$mpg
dm <- xgb.DMatrix(x, label = y, nthread = 1L)
model <- xgb.train(
data = dm,
params = list(
max_depth = 2,
nthread = 1
),
nrounds = 2
)
data_list <- xgb.shap.data(data = df[, -1], model = model, top_n = 2, subsample = 0.8)
expect_equal(NROW(data_list$data), as.integer(0.8 * nrow(df)))
expect_equal(NROW(data_list$data), NROW(data_list$shap_contrib))
})

test_that("prepare.ggplot.shap.data works", {
.skip_if_vcd_not_available()
data_list <- xgb.shap.data(data = sparse_matrix, model = bst.Tree, top_n = 2)
Expand All @@ -472,6 +492,44 @@ test_that("xgb.plot.shap.summary works", {
expect_silent(xgb.ggplot.shap.summary(data = sparse_matrix, model = bst.Tree, top_n = 2))
})

test_that("xgb.plot.shap.summary ignores categorical features", {
.skip_if_vcd_not_available()
data(mtcars)
df <- mtcars
df$cyl <- factor(df$cyl)
levels(df$cyl) <- c("a", "b", "c")
x <- df[, -1]
y <- df$mpg
dm <- xgb.DMatrix(x, label = y, nthread = 1L)
model <- xgb.train(
data = dm,
params = list(
max_depth = 2,
nthread = 1
),
nrounds = 2
)
expect_warning({
xgb.ggplot.shap.summary(data = x, model = model, top_n = 2)
})

x_num <- mtcars[, -1]
x_num$gear <- as.numeric(x_num$gear) - 1
x_num <- as.matrix(x_num)
dm <- xgb.DMatrix(x_num, label = y, feature_types = c(rep("q", 8), "c", "q"), nthread = 1L)
model <- xgb.train(
data = dm,
params = list(
max_depth = 2,
nthread = 1
),
nrounds = 2
)
expect_warning({
xgb.ggplot.shap.summary(data = x_num, model = model, top_n = 2)
})
})

test_that("check.deprecation works", {
ttt <- function(a = NNULL, DUMMY = NULL, ...) {
check.deprecation(...)
Expand Down

0 comments on commit f52f11e

Please sign in to comment.