Skip to content

Commit

Permalink
support base_margin for regr.xgboost
Browse files Browse the repository at this point in the history
  • Loading branch information
bblodfon committed Aug 21, 2024
1 parent afb3ac7 commit d16ea08
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions R/LearnerRegrXgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost",
alpha = p_dbl(0, default = 0, tags = "train"),
approxcontrib = p_lgl(default = FALSE, tags = "predict"),
base_score = p_dbl(default = 0.5, tags = "train"),
base_margin = p_uty(default = NULL, special_vals = list(NULL), tags = "train"),
booster = p_fct(c("gbtree", "gblinear", "dart"), default = "gbtree", tags = "train"),
callbacks = p_uty(default = list(), tags = "train"),
colsample_bylevel = p_dbl(0, 1, default = 1, tags = "train"),
Expand Down Expand Up @@ -206,6 +207,13 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost",
xgboost::setinfo(data, "weight", task$weights$weight)
}

bm = pv$base_margin
pv$base_margin = NULL # silence xgb.train message
bm_is_feature = !is.null(bm) && is.character(bm) && (bm %in% task$feature_names)
if (bm_is_feature) {
xgboost::setinfo(data, "base_margin", task$data(cols = bm)[[1L]])
}

# the last element in the watchlist is used as the early stopping set
internal_valid_task = task$internal_valid_task
if (!is.null(pv$early_stopping_rounds) && is.null(internal_valid_task)) {
Expand All @@ -215,6 +223,9 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost",
test_data = internal_valid_task$data(cols = task$feature_names)
test_target = internal_valid_task$data(cols = task$target_names)
test_data = xgboost::xgb.DMatrix(data = as_numeric_matrix(test_data), label = data.matrix(test_target))
if (bm_is_feature) {
xgboost::setinfo(test_data, "base_margin", internal_valid_task$data(cols = bm)[[1L]])
}
pv$watchlist = c(pv$watchlist, list(test = test_data))
}

Expand Down

0 comments on commit d16ea08

Please sign in to comment.