Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R] On-demand serialization + standardization of attributes #9924

Merged
merged 53 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
09694ac
on-demand serialization, refactor of attributes
david-cortes Dec 24, 2023
ad6490b
solve merge conflicts
david-cortes Dec 26, 2023
27bbdbc
export function for getting booster rounds
david-cortes Dec 26, 2023
88dd947
linter
david-cortes Dec 26, 2023
4a3b5e2
fix incorrect qualifiers
david-cortes Dec 26, 2023
e2331c3
Merge branch 'master' into altrep
david-cortes Dec 26, 2023
147e1cd
remove all references to caret package
david-cortes Dec 26, 2023
e012cce
fix example
david-cortes Dec 26, 2023
f444812
misc fixes
david-cortes Dec 26, 2023
2f30031
allow unsetting booster info
david-cortes Dec 26, 2023
6d4ad8b
remove unused argument
david-cortes Dec 26, 2023
b0054be
more fixes
david-cortes Dec 26, 2023
4050b6f
missing import
david-cortes Dec 26, 2023
2e16f73
swap 'static' with 'namespace'
david-cortes Dec 27, 2023
70affd5
improve wording on compatibility note
david-cortes Dec 27, 2023
af6cdbf
fix non-executed tests and potentially incorrect 'niter_init'
david-cortes Dec 28, 2023
22d4dd7
linter
david-cortes Dec 29, 2023
0465f57
solve merge conflicts
david-cortes Dec 30, 2023
74d5d55
more doc specificity about nrounds reset
david-cortes Dec 30, 2023
1bb74d8
correct function name
david-cortes Dec 30, 2023
c5d711f
solve merge conflicts
david-cortes Dec 31, 2023
b5ec14e
corrections after merge conflicts
david-cortes Dec 31, 2023
ae0de6d
more corrections after merge conflict
david-cortes Dec 31, 2023
1a3d9f7
solve merge conflicts
david-cortes Jan 3, 2024
041dd2f
updates for new default serialization format
david-cortes Jan 8, 2024
7f39bb0
update name for nrounds getter
david-cortes Jan 8, 2024
02c312c
remove in-place training continuation
david-cortes Jan 8, 2024
b4d59f7
change unserialize -> load.raw
david-cortes Jan 8, 2024
24e256a
use R lists instead of JSON text for xgb.config
david-cortes Jan 8, 2024
c97dc1a
remove internal function for nrounds getter
david-cortes Jan 8, 2024
5f8dea5
use _R suffix for all C functions specific to R
david-cortes Jan 8, 2024
8e29769
add test for C and R attributes with saveRDS
david-cortes Jan 8, 2024
9f81e20
add variable.names method for booster
david-cortes Jan 8, 2024
65197e1
add comment about supressed warning
david-cortes Jan 8, 2024
1b58e1b
solve merge conflicts
david-cortes Jan 8, 2024
7dc9b96
update comment
david-cortes Jan 8, 2024
11b213e
update serializers in vignette
david-cortes Jan 8, 2024
c72e663
update vignettes
david-cortes Jan 8, 2024
df88ad9
remove xgb.serialize and xgb.unserialize
david-cortes Jan 9, 2024
74f7f0c
Update R-package/R/xgb.save.R
david-cortes Jan 9, 2024
1ede32e
update docs
david-cortes Jan 9, 2024
37d6b1e
remove 'keep_extra_attributes'
david-cortes Jan 9, 2024
43d938b
remove .Rnw file
david-cortes Jan 9, 2024
692e5a5
add note about booster's R parameters
david-cortes Jan 10, 2024
c161999
user SerializeToBuffer for internal serialization
david-cortes Jan 10, 2024
feedce5
add test for serialization of config
david-cortes Jan 10, 2024
a02abfc
check more attributes
david-cortes Jan 10, 2024
3285ed6
rewrite compatibility note for serialization
david-cortes Jan 10, 2024
8082256
improve wording
david-cortes Jan 10, 2024
6fa7937
update note about attributes in xgb.save
david-cortes Jan 10, 2024
e02ed8f
Update R-package/R/utils.R
david-cortes Jan 10, 2024
ff70221
Update R-package/R/utils.R
david-cortes Jan 10, 2024
d133258
rebuild docs
david-cortes Jan 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions R-package/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@

S3method("[",xgb.DMatrix)
S3method("dimnames<-",xgb.DMatrix)
S3method(coef,xgb.Booster)
S3method(dim,xgb.DMatrix)
S3method(dimnames,xgb.DMatrix)
S3method(getinfo,xgb.Booster)
S3method(getinfo,xgb.DMatrix)
S3method(predict,xgb.Booster)
S3method(predict,xgb.Booster.handle)
S3method(print,xgb.Booster)
S3method(print,xgb.DMatrix)
S3method(print,xgb.cv.synchronous)
S3method(setinfo,xgb.Booster)
S3method(setinfo,xgb.DMatrix)
S3method(slice,xgb.DMatrix)
export("xgb.attr<-")
Expand All @@ -26,22 +28,24 @@ export(cb.save.model)
export(getinfo)
export(setinfo)
export(slice)
export(xgb.Booster.complete)
export(xgb.DMatrix)
export(xgb.DMatrix.hasinfo)
export(xgb.DMatrix.save)
export(xgb.attr)
export(xgb.attributes)
export(xgb.config)
export(xgb.copy.Booster)
export(xgb.create.features)
export(xgb.cv)
export(xgb.dump)
export(xgb.gblinear.history)
export(xgb.get.Booster.nrounds)
trivialfis marked this conversation as resolved.
Show resolved Hide resolved
export(xgb.get.config)
export(xgb.ggplot.deepness)
export(xgb.ggplot.importance)
export(xgb.ggplot.shap.summary)
export(xgb.importance)
export(xgb.is.same.Booster)
export(xgb.load)
export(xgb.load.raw)
export(xgb.model.dt.tree)
Expand Down Expand Up @@ -83,6 +87,7 @@ importFrom(graphics,points)
importFrom(graphics,title)
importFrom(jsonlite,fromJSON)
importFrom(jsonlite,toJSON)
importFrom(stats,coef)
importFrom(stats,median)
importFrom(stats,predict)
importFrom(utils,head)
Expand Down
68 changes: 37 additions & 31 deletions R-package/R/callbacks.R
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ cb.reset.parameters <- function(new_params) {
})

if (!is.null(env$bst)) {
xgb.parameters(env$bst$handle) <- pars
xgb.parameters(env$bst) <- pars
} else {
for (fd in env$bst_folds)
xgb.parameters(fd$bst) <- pars
Expand Down Expand Up @@ -333,13 +333,13 @@ cb.early.stop <- function(stopping_rounds, maximize = FALSE,
if (!is.null(env$bst)) {
if (!inherits(env$bst, 'xgb.Booster'))
stop("'bst' in the parent frame must be an 'xgb.Booster'")
if (!is.null(best_score <- xgb.attr(env$bst$handle, 'best_score'))) {
if (!is.null(best_score <- xgb.attr(env$bst, 'best_score'))) {
best_score <<- as.numeric(best_score)
best_iteration <<- as.numeric(xgb.attr(env$bst$handle, 'best_iteration')) + 1
best_msg <<- as.numeric(xgb.attr(env$bst$handle, 'best_msg'))
best_iteration <<- as.numeric(xgb.attr(env$bst, 'best_iteration')) + 1
best_msg <<- as.numeric(xgb.attr(env$bst, 'best_msg'))
} else {
xgb.attributes(env$bst$handle) <- list(best_iteration = best_iteration - 1,
best_score = best_score)
xgb.attributes(env$bst) <- list(best_iteration = best_iteration - 1,
best_score = best_score)
}
} else if (is.null(env$bst_folds) || is.null(env$basket)) {
stop("Parent frame has neither 'bst' nor ('bst_folds' and 'basket')")
Expand All @@ -348,7 +348,7 @@ cb.early.stop <- function(stopping_rounds, maximize = FALSE,

finalizer <- function(env) {
if (!is.null(env$bst)) {
attr_best_score <- as.numeric(xgb.attr(env$bst$handle, 'best_score'))
attr_best_score <- as.numeric(xgb.attr(env$bst, 'best_score'))
if (best_score != attr_best_score) {
# If the difference is too big, throw an error
if (abs(best_score - attr_best_score) >= 1e-14) {
Expand All @@ -358,9 +358,9 @@ cb.early.stop <- function(stopping_rounds, maximize = FALSE,
# If the difference is due to floating-point truncation, update best_score
best_score <- attr_best_score
}
env$bst$best_iteration <- best_iteration
env$bst$best_ntreelimit <- best_ntreelimit
env$bst$best_score <- best_score
xgb.attr(env$bst, "best_iteration") <- best_iteration
xgb.attr(env$bst, "best_ntreelimit") <- best_ntreelimit
xgb.attr(env$bst, "best_score") <- best_score
} else {
env$basket$best_iteration <- best_iteration
env$basket$best_ntreelimit <- best_ntreelimit
Expand Down Expand Up @@ -440,8 +440,12 @@ cb.save.model <- function(save_period = 0, save_name = "xgboost.model") {
stop("'save_model' callback requires the 'bst' booster object in its calling frame")

if ((save_period > 0 && (env$iteration - env$begin_iteration) %% save_period == 0) ||
(save_period == 0 && env$iteration == env$end_iteration))
xgb.save(env$bst, sprintf(save_name, env$iteration))
(save_period == 0 && env$iteration == env$end_iteration)) {
suppressWarnings({
trivialfis marked this conversation as resolved.
Show resolved Hide resolved
save_name <- sprintf(save_name, env$iteration)
})
xgb.save(env$bst, save_name)
}
}
attr(callback, 'call') <- match.call()
attr(callback, 'name') <- 'cb.save.model'
Expand Down Expand Up @@ -512,8 +516,7 @@ cb.cv.predict <- function(save_models = FALSE) {
env$basket$pred <- pred
if (save_models) {
env$basket$models <- lapply(env$bst_folds, function(fd) {
xgb.attr(fd$bst, 'niter') <- env$end_iteration - 1
xgb.Booster.complete(xgb.handleToBooster(handle = fd$bst, raw = NULL), saveraw = TRUE)
return(fd$bst)
})
}
}
Expand Down Expand Up @@ -665,7 +668,7 @@ cb.gblinear.history <- function(sparse = FALSE) {
} else { # xgb.cv:
cf <- vector("list", length(env$bst_folds))
for (i in seq_along(env$bst_folds)) {
dmp <- xgb.dump(xgb.handleToBooster(handle = env$bst_folds[[i]]$bst, raw = NULL))
dmp <- xgb.dump(env$bst_folds[[i]]$bst)
cf[[i]] <- as.numeric(grep('(booster|bias|weigh)', dmp, invert = TRUE, value = TRUE))
if (sparse) cf[[i]] <- as(cf[[i]], "sparseVector")
}
Expand All @@ -685,14 +688,19 @@ cb.gblinear.history <- function(sparse = FALSE) {
callback
}

#' Extract gblinear coefficients history.
#'
#' A helper function to extract the matrix of linear coefficients' history
#' @title Extract gblinear coefficients history.
#' @description A helper function to extract the matrix of linear coefficients' history
#' from a gblinear model created while using the \code{cb.gblinear.history()}
#' callback.
#' @details Note that this is an R-specific function that relies on R attributes that
#' are not saved when using xgboost's own serialization functions like \link{xgb.load}
#' or \link{xgb.load.raw}.
#'
#' In order for a serialized model to be accepted by tgis function, one must use R
#' serializers such as \link{saveRDS}.
#' @param model either an \code{xgb.Booster} or a result of \code{xgb.cv()}, trained
#' using the \code{cb.gblinear.history()} callback.
#' using the \code{cb.gblinear.history()} callback, but \bold{not} a booster
#' loaded from \link{xgb.load} or \link{xgb.load.raw}.
#' @param class_index zero-based class index to extract the coefficients for only that
#' specific class in a multinomial multiclass model. When it is NULL, all the
#' coefficients are returned. Has no effect in non-multiclass models.
Expand All @@ -713,20 +721,18 @@ xgb.gblinear.history <- function(model, class_index = NULL) {
stop("model must be an object of either xgb.Booster or xgb.cv.synchronous class")
is_cv <- inherits(model, "xgb.cv.synchronous")

if (is.null(model[["callbacks"]]) || is.null(model$callbacks[["cb.gblinear.history"]]))
if (is_cv) {
callbacks <- model$callbacks
} else {
callbacks <- attributes(model)$callbacks
}

if (is.null(callbacks) || is.null(callbacks$cb.gblinear.history))
stop("model must be trained while using the cb.gblinear.history() callback")

if (!is_cv) {
# extract num_class & num_feat from the internal model
dmp <- xgb.dump(model)
if (length(dmp) < 2 || dmp[2] != "bias:")
stop("It does not appear to be a gblinear model")
dmp <- dmp[-c(1, 2)]
n <- which(dmp == 'weight:')
if (length(n) != 1)
stop("It does not appear to be a gblinear model")
num_class <- n - 1
num_feat <- (length(dmp) - 4) / num_class
num_class <- xgb.num_class(model)
num_feat <- xgb.num_feature(model)
} else {
# in case of CV, the object is expected to have this info
if (model$params$booster != "gblinear")
Expand All @@ -742,7 +748,7 @@ xgb.gblinear.history <- function(model, class_index = NULL) {
(class_index[1] < 0 || class_index[1] >= num_class))
stop("class_index has to be within [0,", num_class - 1, "]")

coef_path <- environment(model$callbacks$cb.gblinear.history)[["coefs"]]
coef_path <- environment(callbacks$cb.gblinear.history)[["coefs"]]
if (!is.null(class_index) && num_class > 1) {
coef_path <- if (is.list(coef_path)) {
lapply(coef_path,
Expand Down
29 changes: 17 additions & 12 deletions R-package/R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -148,19 +148,17 @@ check.custom.eval <- function(env = parent.frame()) {


# Update a booster handle for an iteration with dtrain data
xgb.iter.update <- function(booster_handle, dtrain, iter, obj) {
if (!identical(class(booster_handle), "xgb.Booster.handle")) {
stop("booster_handle must be of xgb.Booster.handle class")
}
xgb.iter.update <- function(bst, dtrain, iter, obj) {
if (!inherits(dtrain, "xgb.DMatrix")) {
stop("dtrain must be of xgb.DMatrix class")
}
handle <- xgb.get.handle(bst)

if (is.null(obj)) {
.Call(XGBoosterUpdateOneIter_R, booster_handle, as.integer(iter), dtrain)
.Call(XGBoosterUpdateOneIter_R, handle, as.integer(iter), dtrain)
} else {
pred <- predict(
booster_handle,
bst,
dtrain,
outputmargin = TRUE,
training = TRUE,
Expand All @@ -185,7 +183,7 @@ xgb.iter.update <- function(booster_handle, dtrain, iter, obj) {
}

.Call(
XGBoosterTrainOneIter_R, booster_handle, dtrain, iter, grad, hess
XGBoosterTrainOneIter_R, handle, dtrain, iter, grad, hess
)
}
return(TRUE)
Expand All @@ -195,23 +193,22 @@ xgb.iter.update <- function(booster_handle, dtrain, iter, obj) {
# Evaluate one iteration.
# Returns a named vector of evaluation metrics
# with the names in a 'datasetname-metricname' format.
xgb.iter.eval <- function(booster_handle, watchlist, iter, feval) {
if (!identical(class(booster_handle), "xgb.Booster.handle"))
stop("class of booster_handle must be xgb.Booster.handle")
xgb.iter.eval <- function(bst, watchlist, iter, feval) {
handle <- xgb.get.handle(bst)

if (length(watchlist) == 0)
return(NULL)

evnames <- names(watchlist)
if (is.null(feval)) {
msg <- .Call(XGBoosterEvalOneIter_R, booster_handle, as.integer(iter), watchlist, as.list(evnames))
msg <- .Call(XGBoosterEvalOneIter_R, handle, as.integer(iter), watchlist, as.list(evnames))
mat <- matrix(strsplit(msg, '\\s+|:')[[1]][-1], nrow = 2)
res <- structure(as.numeric(mat[2, ]), names = mat[1, ])
} else {
res <- sapply(seq_along(watchlist), function(j) {
w <- watchlist[[j]]
## predict using all trees
preds <- predict(booster_handle, w, outputmargin = TRUE, iterationrange = c(1, 1))
preds <- predict(bst, w, outputmargin = TRUE, iterationrange = c(1, 1))
eval_res <- feval(preds, w)
out <- eval_res$value
names(out) <- paste0(evnames[j], "-", eval_res$metric)
Expand Down Expand Up @@ -363,6 +360,14 @@ NULL
#' accessible in later releases of XGBoost. To ensure that your model can be accessed in future
#' releases of XGBoost, use \code{\link{xgb.save}} or \code{\link{xgb.save.raw}} instead.
#'
#' Currently, it is not possible to use R serializers like `readRDS` to load an XGBoost. model
#' saved with an XGBoost. version lower than 2.1.0, and it's not possible to load an XGBoost. model
#' saved with R serializers like `readRDS` under XGBoost. version 2.1.0 when using an older version
#' of XGBoost.
hcho3 marked this conversation as resolved.
Show resolved Hide resolved
#'
#' Furthermore, note that using the package `qs` for serialization will require version 0.26 or
#' higher of said package, and will have the same compatibility restrictions as R serializers.
#'
#' @details
#' Use \code{\link{xgb.save}} to save the XGBoost model as a stand-alone file. You may opt into
#' the JSON format by specifying the JSON extension. To read the model back, use
Expand Down
Loading
Loading