Skip to content

Commit

Permalink
[R] Add class names to coefficients (#10745)
Browse files Browse the repository at this point in the history
  • Loading branch information
david-cortes authored Aug 24, 2024
1 parent fd0138c commit 479ae80
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 4 deletions.
16 changes: 12 additions & 4 deletions R-package/R/xgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -1109,17 +1109,25 @@ coef.xgb.Booster <- function(object, ...) {
if (n_cols == 1L) {
out <- c(intercepts, coefs)
if (add_names) {
names(out) <- feature_names
.Call(XGSetVectorNamesInplace_R, out, feature_names)
}
} else {
coefs <- matrix(coefs, nrow = num_feature, byrow = TRUE)
dim(intercepts) <- c(1L, n_cols)
out <- rbind(intercepts, coefs)
out_names <- vector(mode = "list", length = 2)
if (add_names) {
row.names(out) <- feature_names
out_names[[1L]] <- feature_names
}
# TODO: if a class names attributes is added,
# should use those names here.
if (inherits(object, "xgboost")) {
metadata <- attributes(object)$metadata
if (NROW(metadata$y_levels)) {
out_names[[2L]] <- metadata$y_levels
} else if (NROW(metadata$y_names)) {
out_names[[2L]] <- metadata$y_names
}
}
.Call(XGSetArrayDimNamesInplace_R, out, out_names)
}
return(out)
}
Expand Down
13 changes: 13 additions & 0 deletions R-package/tests/testthat/test_basic.R
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,19 @@ test_that("Coefficients from gblinear have the expected shape and names", {
pred_auto <- predict(model, x, outputmargin = TRUE)
pred_manual <- unname(mm %*% coefs)
expect_equal(pred_manual, pred_auto, tolerance = 1e-7)

# xgboost() with additional metadata
model <- xgboost(
iris[, -5],
iris$Species,
booster = "gblinear",
objective = "multi:softprob",
nrounds = 3,
nthread = 1
)
coefs <- coef(model)
expect_equal(row.names(coefs), c("(Intercept)", colnames(x)))
expect_equal(colnames(coefs), levels(iris$Species))
})

test_that("Deep copies work as expected", {
Expand Down

0 comments on commit 479ae80

Please sign in to comment.