Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R] On-demand serialization + standardization of attributes #9924

Merged
merged 53 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from 43 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
09694ac
on-demand serialization, refactor of attributes
david-cortes Dec 24, 2023
ad6490b
solve merge conflicts
david-cortes Dec 26, 2023
27bbdbc
export function for getting booster rounds
david-cortes Dec 26, 2023
88dd947
linter
david-cortes Dec 26, 2023
4a3b5e2
fix incorrect qualifiers
david-cortes Dec 26, 2023
e2331c3
Merge branch 'master' into altrep
david-cortes Dec 26, 2023
147e1cd
remove all references to caret package
david-cortes Dec 26, 2023
e012cce
fix example
david-cortes Dec 26, 2023
f444812
misc fixes
david-cortes Dec 26, 2023
2f30031
allow unsetting booster info
david-cortes Dec 26, 2023
6d4ad8b
remove unused argument
david-cortes Dec 26, 2023
b0054be
more fixes
david-cortes Dec 26, 2023
4050b6f
missing import
david-cortes Dec 26, 2023
2e16f73
swap 'static' with 'namespace'
david-cortes Dec 27, 2023
70affd5
improve wording on compatibility note
david-cortes Dec 27, 2023
af6cdbf
fix non-executed tests and potentially incorrect 'niter_init'
david-cortes Dec 28, 2023
22d4dd7
linter
david-cortes Dec 29, 2023
0465f57
solve merge conflicts
david-cortes Dec 30, 2023
74d5d55
more doc specificity about nrounds reset
david-cortes Dec 30, 2023
1bb74d8
correct function name
david-cortes Dec 30, 2023
c5d711f
solve merge conflicts
david-cortes Dec 31, 2023
b5ec14e
corrections after merge conflicts
david-cortes Dec 31, 2023
ae0de6d
more corrections after merge conflict
david-cortes Dec 31, 2023
1a3d9f7
solve merge conflicts
david-cortes Jan 3, 2024
041dd2f
updates for new default serialization format
david-cortes Jan 8, 2024
7f39bb0
update name for nrounds getter
david-cortes Jan 8, 2024
02c312c
remove in-place training continuation
david-cortes Jan 8, 2024
b4d59f7
change unserialize -> load.raw
david-cortes Jan 8, 2024
24e256a
use R lists instead of JSON text for xgb.config
david-cortes Jan 8, 2024
c97dc1a
remove internal function for nrounds getter
david-cortes Jan 8, 2024
5f8dea5
use _R suffix for all C functions specific to R
david-cortes Jan 8, 2024
8e29769
add test for C and R attributes with saveRDS
david-cortes Jan 8, 2024
9f81e20
add variable.names method for booster
david-cortes Jan 8, 2024
65197e1
add comment about supressed warning
david-cortes Jan 8, 2024
1b58e1b
solve merge conflicts
david-cortes Jan 8, 2024
7dc9b96
update comment
david-cortes Jan 8, 2024
11b213e
update serializers in vignette
david-cortes Jan 8, 2024
c72e663
update vignettes
david-cortes Jan 8, 2024
df88ad9
remove xgb.serialize and xgb.unserialize
david-cortes Jan 9, 2024
74f7f0c
Update R-package/R/xgb.save.R
david-cortes Jan 9, 2024
1ede32e
update docs
david-cortes Jan 9, 2024
37d6b1e
remove 'keep_extra_attributes'
david-cortes Jan 9, 2024
43d938b
remove .Rnw file
david-cortes Jan 9, 2024
692e5a5
add note about booster's R parameters
david-cortes Jan 10, 2024
c161999
user SerializeToBuffer for internal serialization
david-cortes Jan 10, 2024
feedce5
add test for serialization of config
david-cortes Jan 10, 2024
a02abfc
check more attributes
david-cortes Jan 10, 2024
3285ed6
rewrite compatibility note for serialization
david-cortes Jan 10, 2024
8082256
improve wording
david-cortes Jan 10, 2024
6fa7937
update note about attributes in xgb.save
david-cortes Jan 10, 2024
e02ed8f
Update R-package/R/utils.R
david-cortes Jan 10, 2024
ff70221
Update R-package/R/utils.R
david-cortes Jan 10, 2024
d133258
rebuild docs
david-cortes Jan 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions R-package/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@

S3method("[",xgb.DMatrix)
S3method("dimnames<-",xgb.DMatrix)
S3method(coef,xgb.Booster)
S3method(dim,xgb.DMatrix)
S3method(dimnames,xgb.DMatrix)
S3method(getinfo,xgb.Booster)
S3method(getinfo,xgb.DMatrix)
S3method(predict,xgb.Booster)
S3method(predict,xgb.Booster.handle)
S3method(print,xgb.Booster)
S3method(print,xgb.DMatrix)
S3method(print,xgb.cv.synchronous)
S3method(setinfo,xgb.Booster)
S3method(setinfo,xgb.DMatrix)
S3method(slice,xgb.DMatrix)
S3method(variable.names,xgb.Booster)
export("xgb.attr<-")
export("xgb.attributes<-")
export("xgb.config<-")
Expand All @@ -26,13 +29,13 @@ export(cb.save.model)
export(getinfo)
export(setinfo)
export(slice)
export(xgb.Booster.complete)
export(xgb.DMatrix)
export(xgb.DMatrix.hasinfo)
export(xgb.DMatrix.save)
export(xgb.attr)
export(xgb.attributes)
export(xgb.config)
export(xgb.copy.Booster)
export(xgb.create.features)
export(xgb.cv)
export(xgb.dump)
Expand All @@ -41,10 +44,12 @@ export(xgb.get.DMatrix.data)
export(xgb.get.DMatrix.num.non.missing)
export(xgb.get.DMatrix.qcut)
export(xgb.get.config)
export(xgb.get.num.boosted.rounds)
export(xgb.ggplot.deepness)
export(xgb.ggplot.importance)
export(xgb.ggplot.shap.summary)
export(xgb.importance)
export(xgb.is.same.Booster)
export(xgb.load)
export(xgb.load.raw)
export(xgb.model.dt.tree)
Expand All @@ -56,10 +61,8 @@ export(xgb.plot.shap.summary)
export(xgb.plot.tree)
export(xgb.save)
export(xgb.save.raw)
export(xgb.serialize)
export(xgb.set.config)
export(xgb.train)
export(xgb.unserialize)
export(xgboost)
import(methods)
importClassesFrom(Matrix,dgCMatrix)
Expand Down Expand Up @@ -88,8 +91,10 @@ importFrom(graphics,title)
importFrom(jsonlite,fromJSON)
importFrom(jsonlite,toJSON)
importFrom(methods,new)
importFrom(stats,coef)
importFrom(stats,median)
importFrom(stats,predict)
importFrom(stats,variable.names)
importFrom(utils,head)
importFrom(utils,object.size)
importFrom(utils,str)
Expand Down
81 changes: 46 additions & 35 deletions R-package/R/callbacks.R
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ cb.reset.parameters <- function(new_params) {
})

if (!is.null(env$bst)) {
xgb.parameters(env$bst$handle) <- pars
xgb.parameters(env$bst) <- pars
} else {
for (fd in env$bst_folds)
xgb.parameters(fd$bst) <- pars
Expand Down Expand Up @@ -333,13 +333,13 @@ cb.early.stop <- function(stopping_rounds, maximize = FALSE,
if (!is.null(env$bst)) {
if (!inherits(env$bst, 'xgb.Booster'))
stop("'bst' in the parent frame must be an 'xgb.Booster'")
if (!is.null(best_score <- xgb.attr(env$bst$handle, 'best_score'))) {
if (!is.null(best_score <- xgb.attr(env$bst, 'best_score'))) {
best_score <<- as.numeric(best_score)
best_iteration <<- as.numeric(xgb.attr(env$bst$handle, 'best_iteration')) + 1
best_msg <<- as.numeric(xgb.attr(env$bst$handle, 'best_msg'))
best_iteration <<- as.numeric(xgb.attr(env$bst, 'best_iteration')) + 1
best_msg <<- as.numeric(xgb.attr(env$bst, 'best_msg'))
} else {
xgb.attributes(env$bst$handle) <- list(best_iteration = best_iteration - 1,
best_score = best_score)
xgb.attributes(env$bst) <- list(best_iteration = best_iteration - 1,
best_score = best_score)
}
} else if (is.null(env$bst_folds) || is.null(env$basket)) {
stop("Parent frame has neither 'bst' nor ('bst_folds' and 'basket')")
Expand All @@ -348,7 +348,7 @@ cb.early.stop <- function(stopping_rounds, maximize = FALSE,

finalizer <- function(env) {
if (!is.null(env$bst)) {
attr_best_score <- as.numeric(xgb.attr(env$bst$handle, 'best_score'))
attr_best_score <- as.numeric(xgb.attr(env$bst, 'best_score'))
if (best_score != attr_best_score) {
# If the difference is too big, throw an error
if (abs(best_score - attr_best_score) >= 1e-14) {
Expand All @@ -358,9 +358,9 @@ cb.early.stop <- function(stopping_rounds, maximize = FALSE,
# If the difference is due to floating-point truncation, update best_score
best_score <- attr_best_score
}
env$bst$best_iteration <- best_iteration
env$bst$best_ntreelimit <- best_ntreelimit
env$bst$best_score <- best_score
xgb.attr(env$bst, "best_iteration") <- best_iteration
xgb.attr(env$bst, "best_ntreelimit") <- best_ntreelimit
xgb.attr(env$bst, "best_score") <- best_score
} else {
env$basket$best_iteration <- best_iteration
env$basket$best_ntreelimit <- best_ntreelimit
Expand Down Expand Up @@ -412,11 +412,15 @@ cb.early.stop <- function(stopping_rounds, maximize = FALSE,
#' @param save_period save the model to disk after every
#' \code{save_period} iterations; 0 means save the model at the end.
#' @param save_name the name or path for the saved model file.
#'
#' Note that the format of the model being saved is determined by the file
#' extension specified here (see \link{xgb.save} for details about how it works).
#'
#' It can contain a \code{\link[base]{sprintf}} formatting specifier
#' to include the integer iteration number in the file name.
#' E.g., with \code{save_name} = 'xgboost_%04d.model',
#' the file saved at iteration 50 would be named "xgboost_0050.model".
#'
#' E.g., with \code{save_name} = 'xgboost_%04d.ubj',
#' the file saved at iteration 50 would be named "xgboost_0050.ubj".
#' @seealso \link{xgb.save}
#' @details
#' This callback function allows to save an xgb-model file, either periodically after each \code{save_period}'s or at the end.
#'
Expand All @@ -430,7 +434,7 @@ cb.early.stop <- function(stopping_rounds, maximize = FALSE,
#' \code{\link{callbacks}}
#'
#' @export
cb.save.model <- function(save_period = 0, save_name = "xgboost.model") {
cb.save.model <- function(save_period = 0, save_name = "xgboost.ubj") {

if (save_period < 0)
stop("'save_period' cannot be negative")
Expand All @@ -440,8 +444,13 @@ cb.save.model <- function(save_period = 0, save_name = "xgboost.model") {
stop("'save_model' callback requires the 'bst' booster object in its calling frame")

if ((save_period > 0 && (env$iteration - env$begin_iteration) %% save_period == 0) ||
(save_period == 0 && env$iteration == env$end_iteration))
xgb.save(env$bst, sprintf(save_name, env$iteration))
(save_period == 0 && env$iteration == env$end_iteration)) {
# Note: this throws a warning if the name doesn't have anything to format through 'sprintf'
suppressWarnings({
trivialfis marked this conversation as resolved.
Show resolved Hide resolved
save_name <- sprintf(save_name, env$iteration)
})
xgb.save(env$bst, save_name)
}
}
attr(callback, 'call') <- match.call()
attr(callback, 'name') <- 'cb.save.model'
Expand Down Expand Up @@ -512,8 +521,7 @@ cb.cv.predict <- function(save_models = FALSE) {
env$basket$pred <- pred
if (save_models) {
env$basket$models <- lapply(env$bst_folds, function(fd) {
xgb.attr(fd$bst, 'niter') <- env$end_iteration - 1
xgb.Booster.complete(xgb.handleToBooster(handle = fd$bst, raw = NULL), saveraw = TRUE)
return(fd$bst)
})
}
}
Expand Down Expand Up @@ -665,7 +673,7 @@ cb.gblinear.history <- function(sparse = FALSE) {
} else { # xgb.cv:
cf <- vector("list", length(env$bst_folds))
for (i in seq_along(env$bst_folds)) {
dmp <- xgb.dump(xgb.handleToBooster(handle = env$bst_folds[[i]]$bst, raw = NULL))
dmp <- xgb.dump(env$bst_folds[[i]]$bst)
cf[[i]] <- as.numeric(grep('(booster|bias|weigh)', dmp, invert = TRUE, value = TRUE))
if (sparse) cf[[i]] <- as(cf[[i]], "sparseVector")
}
Expand All @@ -685,14 +693,19 @@ cb.gblinear.history <- function(sparse = FALSE) {
callback
}

#' Extract gblinear coefficients history.
#'
#' A helper function to extract the matrix of linear coefficients' history
#' @title Extract gblinear coefficients history.
#' @description A helper function to extract the matrix of linear coefficients' history
#' from a gblinear model created while using the \code{cb.gblinear.history()}
#' callback.
#' @details Note that this is an R-specific function that relies on R attributes that
#' are not saved when using xgboost's own serialization functions like \link{xgb.load}
#' or \link{xgb.load.raw}.
#'
#' In order for a serialized model to be accepted by tgis function, one must use R
#' serializers such as \link{saveRDS}.
#' @param model either an \code{xgb.Booster} or a result of \code{xgb.cv()}, trained
#' using the \code{cb.gblinear.history()} callback.
#' using the \code{cb.gblinear.history()} callback, but \bold{not} a booster
#' loaded from \link{xgb.load} or \link{xgb.load.raw}.
#' @param class_index zero-based class index to extract the coefficients for only that
#' specific class in a multinomial multiclass model. When it is NULL, all the
#' coefficients are returned. Has no effect in non-multiclass models.
Expand All @@ -713,20 +726,18 @@ xgb.gblinear.history <- function(model, class_index = NULL) {
stop("model must be an object of either xgb.Booster or xgb.cv.synchronous class")
is_cv <- inherits(model, "xgb.cv.synchronous")

if (is.null(model[["callbacks"]]) || is.null(model$callbacks[["cb.gblinear.history"]]))
if (is_cv) {
callbacks <- model$callbacks
} else {
callbacks <- attributes(model)$callbacks
}

if (is.null(callbacks) || is.null(callbacks$cb.gblinear.history))
stop("model must be trained while using the cb.gblinear.history() callback")

if (!is_cv) {
# extract num_class & num_feat from the internal model
dmp <- xgb.dump(model)
if (length(dmp) < 2 || dmp[2] != "bias:")
stop("It does not appear to be a gblinear model")
dmp <- dmp[-c(1, 2)]
n <- which(dmp == 'weight:')
if (length(n) != 1)
stop("It does not appear to be a gblinear model")
num_class <- n - 1
num_feat <- (length(dmp) - 4) / num_class
num_class <- xgb.num_class(model)
num_feat <- xgb.num_feature(model)
} else {
# in case of CV, the object is expected to have this info
if (model$params$booster != "gblinear")
Expand All @@ -742,7 +753,7 @@ xgb.gblinear.history <- function(model, class_index = NULL) {
(class_index[1] < 0 || class_index[1] >= num_class))
stop("class_index has to be within [0,", num_class - 1, "]")

coef_path <- environment(model$callbacks$cb.gblinear.history)[["coefs"]]
coef_path <- environment(callbacks$cb.gblinear.history)[["coefs"]]
if (!is.null(class_index) && num_class > 1) {
coef_path <- if (is.list(coef_path)) {
lapply(coef_path,
Expand Down
33 changes: 17 additions & 16 deletions R-package/R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -148,19 +148,17 @@ check.custom.eval <- function(env = parent.frame()) {


# Update a booster handle for an iteration with dtrain data
xgb.iter.update <- function(booster_handle, dtrain, iter, obj) {
if (!identical(class(booster_handle), "xgb.Booster.handle")) {
stop("booster_handle must be of xgb.Booster.handle class")
}
xgb.iter.update <- function(bst, dtrain, iter, obj) {
if (!inherits(dtrain, "xgb.DMatrix")) {
stop("dtrain must be of xgb.DMatrix class")
}
handle <- xgb.get.handle(bst)

if (is.null(obj)) {
.Call(XGBoosterUpdateOneIter_R, booster_handle, as.integer(iter), dtrain)
.Call(XGBoosterUpdateOneIter_R, handle, as.integer(iter), dtrain)
} else {
pred <- predict(
booster_handle,
bst,
dtrain,
outputmargin = TRUE,
training = TRUE,
Expand All @@ -185,7 +183,7 @@ xgb.iter.update <- function(booster_handle, dtrain, iter, obj) {
}

.Call(
XGBoosterTrainOneIter_R, booster_handle, dtrain, iter, grad, hess
XGBoosterTrainOneIter_R, handle, dtrain, iter, grad, hess
)
}
return(TRUE)
Expand All @@ -195,23 +193,22 @@ xgb.iter.update <- function(booster_handle, dtrain, iter, obj) {
# Evaluate one iteration.
# Returns a named vector of evaluation metrics
# with the names in a 'datasetname-metricname' format.
xgb.iter.eval <- function(booster_handle, watchlist, iter, feval) {
if (!identical(class(booster_handle), "xgb.Booster.handle"))
stop("class of booster_handle must be xgb.Booster.handle")
xgb.iter.eval <- function(bst, watchlist, iter, feval) {
handle <- xgb.get.handle(bst)

if (length(watchlist) == 0)
return(NULL)

evnames <- names(watchlist)
if (is.null(feval)) {
msg <- .Call(XGBoosterEvalOneIter_R, booster_handle, as.integer(iter), watchlist, as.list(evnames))
msg <- .Call(XGBoosterEvalOneIter_R, handle, as.integer(iter), watchlist, as.list(evnames))
mat <- matrix(strsplit(msg, '\\s+|:')[[1]][-1], nrow = 2)
res <- structure(as.numeric(mat[2, ]), names = mat[1, ])
} else {
res <- sapply(seq_along(watchlist), function(j) {
w <- watchlist[[j]]
## predict using all trees
preds <- predict(booster_handle, w, outputmargin = TRUE, iterationrange = c(1, 1))
preds <- predict(bst, w, outputmargin = TRUE, iterationrange = c(1, 1))
eval_res <- feval(preds, w)
out <- eval_res$value
names(out) <- paste0(evnames[j], "-", eval_res$metric)
Expand Down Expand Up @@ -363,6 +360,14 @@ NULL
#' accessible in later releases of XGBoost. To ensure that your model can be accessed in future
#' releases of XGBoost, use \code{\link{xgb.save}} or \code{\link{xgb.save.raw}} instead.
#'
#' Note that XGBoost models in R starting from version `2.1.0` and onwards, and XGBoost models
#' before version `2.1.0`; have a very different R object structure and are incompatible with
#' each other. Hence, models that were saved with R serializers live `saveRDS` or `save` before
#' version `2.1.0` will not work with latter `xgboost` versions and vice versa.
#'
#' Furthermore, note that using the package `qs` for serialization will require version 0.26 or
#' higher of said package, and will have the same compatibility restrictions as R serializers.
#'
#' @details
#' Use \code{\link{xgb.save}} to save the XGBoost model as a stand-alone file. You may opt into
#' the JSON format by specifying the JSON extension. To read the model back, use
Expand All @@ -374,10 +379,6 @@ NULL
#' The \code{\link{xgb.save.raw}} function is useful if you'd like to persist the XGBoost model
#' as part of another R object.
#'
#' Note: Do not use \code{\link{xgb.serialize}} to store models long-term. It persists not only the
#' model but also internal configurations and parameters, and its format is not stable across
#' multiple XGBoost versions. Use \code{\link{xgb.serialize}} only for checkpointing.
#'
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the replacement for xgb.serialize? Should users use saveRDS to fully preserve all attributes?

#' For more details and explanation about model persistence and archival, consult the page
#' \url{https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html}.
#'
Expand Down
Loading
Loading