Skip to content

Commit

Permalink
avoid using task$data twice
Browse files Browse the repository at this point in the history
  • Loading branch information
bblodfon committed Aug 22, 2024
1 parent 239b109 commit eeda8f4
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions R/LearnerRegrXgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -201,17 +201,17 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost",

data = task$data(cols = task$feature_names)
target = task$data(cols = task$target_names)
data = xgboost::xgb.DMatrix(data = as_numeric_matrix(data), label = data.matrix(target))
xgb_data = xgboost::xgb.DMatrix(data = as_numeric_matrix(data), label = data.matrix(target))

if ("weights" %in% task$properties) {
xgboost::setinfo(data, "weight", task$weights$weight)
xgboost::setinfo(xgb_data, "weight", task$weights$weight)
}

bm = pv$base_margin
pv$base_margin = NULL # silence xgb.train message
bm_is_feature = !is.null(bm) && is.character(bm) && (bm %in% task$feature_names)
if (bm_is_feature) {
xgboost::setinfo(data, "base_margin", task$data(cols = bm)[[1L]])
xgboost::setinfo(xgb_data, "base_margin", data[[bm]])
}

# the last element in the watchlist is used as the early stopping set
Expand All @@ -222,11 +222,11 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost",
if (!is.null(internal_valid_task)) {
test_data = internal_valid_task$data(cols = task$feature_names)
test_target = internal_valid_task$data(cols = task$target_names)
test_data = xgboost::xgb.DMatrix(data = as_numeric_matrix(test_data), label = data.matrix(test_target))
xgb_test_data = xgboost::xgb.DMatrix(data = as_numeric_matrix(test_data), label = data.matrix(test_target))
if (bm_is_feature) {
xgboost::setinfo(test_data, "base_margin", internal_valid_task$data(cols = bm)[[1L]])
xgboost::setinfo(xgb_test_data, "base_margin", test_data[[bm]])
}
pv$watchlist = c(pv$watchlist, list(test = test_data))
pv$watchlist = c(pv$watchlist, list(test = xgb_test_data))
}

# set internal validation measure
Expand All @@ -246,7 +246,7 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost",
pv$maximize = !measure$minimize
}

invoke(xgboost::xgb.train, data = data, .args = pv)
invoke(xgboost::xgb.train, data = xgb_data, .args = pv)
},
#' Returns the `$best_iteration` when early stopping is activated.
.predict = function(task) {
Expand Down

0 comments on commit eeda8f4

Please sign in to comment.