Skip to content

Commit

Permalink
[R] Refactor field logic for dmatrix (#9901)
Browse files Browse the repository at this point in the history
  • Loading branch information
david-cortes authored Dec 18, 2023
1 parent 0edd600 commit ff3d82c
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 16 deletions.
1 change: 1 addition & 0 deletions R-package/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ export(setinfo)
export(slice)
export(xgb.Booster.complete)
export(xgb.DMatrix)
export(xgb.DMatrix.hasinfo)
export(xgb.DMatrix.save)
export(xgb.attr)
export(xgb.attributes)
Expand Down
70 changes: 56 additions & 14 deletions R-package/R/xgb.DMatrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,10 @@ xgb.DMatrix <- function(
}

dmat <- handle
attributes(dmat) <- list(class = "xgb.DMatrix")
attributes(dmat) <- list(
class = "xgb.DMatrix",
fields = new.env()
)

if (!is.null(label)) {
setinfo(dmat, "label", label)
Expand Down Expand Up @@ -199,6 +202,35 @@ xgb.DMatrix <- function(
return(dmat)
}

#' @title Check whether DMatrix object has a field
#' @description Checks whether an xgb.DMatrix object has a given field assigned to
#' it, such as weights, labels, etc.
#' @param object The DMatrix object to check for the given \code{info} field.
#' @param info The field to check for presence or absence in \code{object}.
#' @seealso \link{xgb.DMatrix}, \link{getinfo.xgb.DMatrix}, \link{setinfo.xgb.DMatrix}
#' @examples
#' library(xgboost)
#' x <- matrix(1:10, nrow = 5)
#' dm <- xgb.DMatrix(x, nthread = 1)
#'
#' # 'dm' so far doesn't have any fields set
#' xgb.DMatrix.hasinfo(dm, "label")
#'
#' # Fields can be added after construction
#' setinfo(dm, "label", 1:5)
#' xgb.DMatrix.hasinfo(dm, "label")
#' @export
xgb.DMatrix.hasinfo <- function(object, info) {
if (!inherits(object, "xgb.DMatrix")) {
stop("Object is not an 'xgb.DMatrix'.")
}
if (.Call(XGCheckNullPtr_R, object)) {
warning("xgb.DMatrix object is invalid. Must be constructed again.")
return(FALSE)
}
return(NVL(attr(object, "fields")[[info]], FALSE))
}


# get dmatrix from data, label
# internal helper method
Expand Down Expand Up @@ -389,7 +421,7 @@ getinfo.xgb.DMatrix <- function(object, name, ...) {
#' @param object Object of class "xgb.DMatrix"
#' @param name the name of the field to get
#' @param info the specific field of information to set
#' @param ... other parameters
#' @param ... Not used.
#'
#' @details
#' See the documentation for \link{xgb.DMatrix} for possible fields that can be set
Expand Down Expand Up @@ -418,26 +450,32 @@ setinfo <- function(object, ...) UseMethod("setinfo")
#' @rdname setinfo
#' @export
setinfo.xgb.DMatrix <- function(object, name, info, ...) {
.internal.setinfo.xgb.DMatrix(object, name, info, ...)
attr(object, "fields")[[name]] <- TRUE
return(TRUE)
}

.internal.setinfo.xgb.DMatrix <- function(object, name, info, ...) {
if (name == "label") {
if (NROW(info) != nrow(object))
stop("The length of labels must equal to the number of rows in the input data")
.Call(XGDMatrixSetInfo_R, object, name, info)
return(TRUE)
}
if (name == "label_lower_bound") {
if (length(info) != nrow(object))
if (NROW(info) != nrow(object))
stop("The length of lower-bound labels must equal to the number of rows in the input data")
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
.Call(XGDMatrixSetInfo_R, object, name, info)
return(TRUE)
}
if (name == "label_upper_bound") {
if (length(info) != nrow(object))
if (NROW(info) != nrow(object))
stop("The length of upper-bound labels must equal to the number of rows in the input data")
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
.Call(XGDMatrixSetInfo_R, object, name, info)
return(TRUE)
}
if (name == "weight") {
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
.Call(XGDMatrixSetInfo_R, object, name, info)
return(TRUE)
}
if (name == "base_margin") {
Expand All @@ -447,20 +485,20 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) {
if (name == "group") {
if (sum(info) != nrow(object))
stop("The sum of groups must equal to the number of rows in the input data")
.Call(XGDMatrixSetInfo_R, object, name, as.integer(info))
.Call(XGDMatrixSetInfo_R, object, name, info)
return(TRUE)
}
if (name == "qid") {
if (NROW(info) != nrow(object))
stop("The length of qid assignments must equal to the number of rows in the input data")
.Call(XGDMatrixSetInfo_R, object, name, as.integer(info))
.Call(XGDMatrixSetInfo_R, object, name, info)
return(TRUE)
}
if (name == "feature_weights") {
if (length(info) != ncol(object)) {
if (NROW(info) != ncol(object)) {
stop("The number of feature weights must equal to the number of columns in the input data")
}
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
.Call(XGDMatrixSetInfo_R, object, name, info)
return(TRUE)
}

Expand Down Expand Up @@ -568,11 +606,15 @@ slice.xgb.DMatrix <- function(object, idxset, ...) {
#' @method print xgb.DMatrix
#' @export
print.xgb.DMatrix <- function(x, verbose = FALSE, ...) {
if (.Call(XGCheckNullPtr_R, x)) {
cat("INVALID xgb.DMatrix object. Must be constructed anew.\n")
return(invisible(x))
}
cat('xgb.DMatrix dim:', nrow(x), 'x', ncol(x), ' info: ')
infos <- character(0)
if (length(getinfo(x, 'label')) > 0) infos <- 'label'
if (length(getinfo(x, 'weight')) > 0) infos <- c(infos, 'weight')
if (length(getinfo(x, 'base_margin')) > 0) infos <- c(infos, 'base_margin')
if (xgb.DMatrix.hasinfo(x, 'label')) infos <- 'label'
if (xgb.DMatrix.hasinfo(x, 'weight')) infos <- c(infos, 'weight')
if (xgb.DMatrix.hasinfo(x, 'base_margin')) infos <- c(infos, 'base_margin')
if (length(infos) == 0) infos <- 'NA'
cat(infos)
cnames <- colnames(x)
Expand Down
5 changes: 4 additions & 1 deletion R-package/R/xgb.cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
early_stopping_rounds = NULL, maximize = NULL, callbacks = list(), ...) {

check.deprecation(...)
if (inherits(data, "xgb.DMatrix") && .Call(XGCheckNullPtr_R, data)) {
stop("'data' is an invalid 'xgb.DMatrix' object. Must be constructed again.")
}

params <- check.booster.params(params, ...)
# TODO: should we deprecate the redundant 'metrics' parameter?
Expand All @@ -136,7 +139,7 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
check.custom.eval()

# Check the labels
if ((inherits(data, 'xgb.DMatrix') && is.null(getinfo(data, 'label'))) ||
if ((inherits(data, 'xgb.DMatrix') && !xgb.DMatrix.hasinfo(data, 'label')) ||
(!inherits(data, 'xgb.DMatrix') && is.null(label))) {
stop("Labels must be provided for CV either through xgb.DMatrix, or through 'label=' when 'data' is matrix")
} else if (inherits(data, 'xgb.DMatrix')) {
Expand Down
2 changes: 1 addition & 1 deletion R-package/man/setinfo.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

32 changes: 32 additions & 0 deletions R-package/man/xgb.DMatrix.hasinfo.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit ff3d82c

Please sign in to comment.