Skip to content

Commit

Permalink
[R] remove default values in internal utility functions (#9457)
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb authored Aug 10, 2023
1 parent 9dbb714 commit 44bd298
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 9 deletions.
8 changes: 4 additions & 4 deletions R-package/R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ check.custom.eval <- function(env = parent.frame()) {


# Update a booster handle for an iteration with dtrain data
xgb.iter.update <- function(booster_handle, dtrain, iter, obj = NULL) {
xgb.iter.update <- function(booster_handle, dtrain, iter, obj) {
if (!identical(class(booster_handle), "xgb.Booster.handle")) {
stop("booster_handle must be of xgb.Booster.handle class")
}
Expand All @@ -163,7 +163,7 @@ xgb.iter.update <- function(booster_handle, dtrain, iter, obj = NULL) {
# Evaluate one iteration.
# Returns a named vector of evaluation metrics
# with the names in a 'datasetname-metricname' format.
xgb.iter.eval <- function(booster_handle, watchlist, iter, feval = NULL) {
xgb.iter.eval <- function(booster_handle, watchlist, iter, feval) {
if (!identical(class(booster_handle), "xgb.Booster.handle"))
stop("class of booster_handle must be xgb.Booster.handle")

Expand Down Expand Up @@ -234,7 +234,7 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
y <- factor(y)
}
}
folds <- xgb.createFolds(y, nfold)
folds <- xgb.createFolds(y = y, k = nfold)
} else {
# make simple non-stratified folds
kstep <- length(rnd_idx) %/% nfold
Expand All @@ -251,7 +251,7 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
# Creates CV folds stratified by the values of y.
# It was borrowed from caret::createFolds and simplified
# by always returning an unnamed list of fold indices.
xgb.createFolds <- function(y, k = 10) {
xgb.createFolds <- function(y, k) {
if (is.numeric(y)) {
## Group the numeric data based on their magnitudes
## and sample within those groups.
Expand Down
14 changes: 12 additions & 2 deletions R-package/R/xgb.cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,18 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
for (f in cb$pre_iter) f()

msg <- lapply(bst_folds, function(fd) {
xgb.iter.update(fd$bst, fd$dtrain, iteration - 1, obj)
xgb.iter.eval(fd$bst, fd$watchlist, iteration - 1, feval)
xgb.iter.update(
booster_handle = fd$bst,
dtrain = fd$dtrain,
iter = iteration - 1,
obj = obj
)
xgb.iter.eval(
booster_handle = fd$bst,
watchlist = fd$watchlist,
iter = iteration - 1,
feval = feval
)
})
msg <- simplify2array(msg)
bst_evaluation <- rowMeans(msg)
Expand Down
17 changes: 14 additions & 3 deletions R-package/R/xgb.train.R
Original file line number Diff line number Diff line change
Expand Up @@ -390,10 +390,21 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),

for (f in cb$pre_iter) f()

xgb.iter.update(bst$handle, dtrain, iteration - 1, obj)
xgb.iter.update(
booster_handle = bst$handle,
dtrain = dtrain,
iter = iteration - 1,
obj = obj
)

if (length(watchlist) > 0)
bst_evaluation <- xgb.iter.eval(bst$handle, watchlist, iteration - 1, feval) # nolint: object_usage_linter
if (length(watchlist) > 0) {
bst_evaluation <- xgb.iter.eval( # nolint: object_usage_linter
booster_handle = bst$handle,
watchlist = watchlist,
iter = iteration - 1,
feval = feval
)
}

xgb.attr(bst$handle, 'niter') <- iteration - 1

Expand Down

0 comments on commit 44bd298

Please sign in to comment.