Skip to content

Commit

Permalink
[R] allow using seed with regular RNG (#10029)
Browse files Browse the repository at this point in the history
  • Loading branch information
david-cortes authored Feb 4, 2024
1 parent 662854c commit a730c7e
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 15 deletions.
9 changes: 7 additions & 2 deletions R-package/R/xgb.train.R
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,11 @@
#' Number of threads can also be manually specified via the \code{nthread}
#' parameter.
#'
#' While in other interfaces, the default random seed defaults to zero, in R, if a parameter `seed`
#' is not manually supplied, it will generate a random seed through R's own random number generator,
#' whose seed in turn is controllable through `set.seed`. If `seed` is passed, it will override the
#' RNG from R.
#'
#' The evaluation metric is chosen automatically by XGBoost (according to the objective)
#' when the \code{eval_metric} parameter is not provided.
#' User may set one or several \code{eval_metric} parameters.
Expand Down Expand Up @@ -363,8 +368,8 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
# Sort the callbacks into categories
cb <- categorize.callbacks(callbacks)
params['validate_parameters'] <- TRUE
if (!is.null(params[['seed']])) {
warning("xgb.train: `seed` is ignored in R package. Use `set.seed()` instead.")
if (!("seed" %in% names(params))) {
params[["seed"]] <- sample(.Machine$integer.max, size = 1)
}

# The tree updating process would need slightly different handling
Expand Down
5 changes: 5 additions & 0 deletions R-package/man/xgb.train.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 0 additions & 10 deletions R-package/src/xgboost_custom.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,6 @@ double LogGamma(double v) {
return lgammafn(v);
}
#endif // !defined(XGBOOST_USE_CUDA)
// customize random engine.
void CustomGlobalRandomEngine::seed(CustomGlobalRandomEngine::result_type val) {
// ignore the seed
}

// use R's PRNG to replacd
CustomGlobalRandomEngine::result_type
CustomGlobalRandomEngine::operator()() {
return static_cast<result_type>(
std::floor(unif_rand() * CustomGlobalRandomEngine::max()));
}
} // namespace common
} // namespace xgboost
63 changes: 63 additions & 0 deletions R-package/tests/testthat/test_basic.R
Original file line number Diff line number Diff line change
Expand Up @@ -778,3 +778,66 @@ test_that("DMatrix field are set to booster when training", {
expect_equal(getinfo(model_feature_types, "feature_type"), c("q", "c", "q"))
expect_equal(getinfo(model_both, "feature_type"), c("q", "c", "q"))
})

test_that("Seed in params override PRNG from R", {
set.seed(123)
model1 <- xgb.train(
data = xgb.DMatrix(
agaricus.train$data,
label = agaricus.train$label, nthread = 1L
),
params = list(
objective = "binary:logistic",
max_depth = 3L,
subsample = 0.1,
colsample_bytree = 0.1,
seed = 111L
),
nrounds = 3L
)

set.seed(456)
model2 <- xgb.train(
data = xgb.DMatrix(
agaricus.train$data,
label = agaricus.train$label, nthread = 1L
),
params = list(
objective = "binary:logistic",
max_depth = 3L,
subsample = 0.1,
colsample_bytree = 0.1,
seed = 111L
),
nrounds = 3L
)

expect_equal(
xgb.save.raw(model1, raw_format = "json"),
xgb.save.raw(model2, raw_format = "json")
)

set.seed(123)
model3 <- xgb.train(
data = xgb.DMatrix(
agaricus.train$data,
label = agaricus.train$label, nthread = 1L
),
params = list(
objective = "binary:logistic",
max_depth = 3L,
subsample = 0.1,
colsample_bytree = 0.1,
seed = 222L
),
nrounds = 3L
)
expect_false(
isTRUE(
all.equal(
xgb.save.raw(model1, raw_format = "json"),
xgb.save.raw(model3, raw_format = "json")
)
)
)
})
2 changes: 1 addition & 1 deletion doc/parameter.rst
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ Specify the learning task and the corresponding learning objective. The objectiv

* ``seed`` [default=0]

- Random number seed. This parameter is ignored in R package, use `set.seed()` instead.
- Random number seed. In the R package, if not specified, instead of defaulting to seed 'zero', will take a random seed through R's own RNG engine.

* ``seed_per_iteration`` [default= ``false``]

Expand Down
2 changes: 1 addition & 1 deletion include/xgboost/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
* \brief Whether to customize global PRNG.
*/
#ifndef XGBOOST_CUSTOMIZE_GLOBAL_PRNG
#define XGBOOST_CUSTOMIZE_GLOBAL_PRNG XGBOOST_STRICT_R_MODE
#define XGBOOST_CUSTOMIZE_GLOBAL_PRNG 0
#endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG

/*!
Expand Down
2 changes: 1 addition & 1 deletion src/common/random.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ namespace xgboost::common {
*/
using RandomEngine = std::mt19937;

#if XGBOOST_CUSTOMIZE_GLOBAL_PRNG
#if defined(XGBOOST_CUSTOMIZE_GLOBAL_PRNG) && XGBOOST_CUSTOMIZE_GLOBAL_PRNG == 1
/*!
* \brief An customized random engine, used to be plugged in PRNG from other systems.
* The implementation of this library is not provided by xgboost core library.
Expand Down

0 comments on commit a730c7e

Please sign in to comment.