From eeda8f49d39b088096075c791d4956bd8e1f361d Mon Sep 17 00:00:00 2001 From: john Date: Thu, 22 Aug 2024 17:33:55 +0200 Subject: [PATCH] avoid using task$data twice --- R/LearnerRegrXgboost.R | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/R/LearnerRegrXgboost.R b/R/LearnerRegrXgboost.R index 3ddcd04c..aea51525 100644 --- a/R/LearnerRegrXgboost.R +++ b/R/LearnerRegrXgboost.R @@ -201,17 +201,17 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost", data = task$data(cols = task$feature_names) target = task$data(cols = task$target_names) - data = xgboost::xgb.DMatrix(data = as_numeric_matrix(data), label = data.matrix(target)) + xgb_data = xgboost::xgb.DMatrix(data = as_numeric_matrix(data), label = data.matrix(target)) if ("weights" %in% task$properties) { - xgboost::setinfo(data, "weight", task$weights$weight) + xgboost::setinfo(xgb_data, "weight", task$weights$weight) } bm = pv$base_margin pv$base_margin = NULL # silence xgb.train message bm_is_feature = !is.null(bm) && is.character(bm) && (bm %in% task$feature_names) if (bm_is_feature) { - xgboost::setinfo(data, "base_margin", task$data(cols = bm)[[1L]]) + xgboost::setinfo(xgb_data, "base_margin", data[[bm]]) } # the last element in the watchlist is used as the early stopping set @@ -222,11 +222,11 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost", if (!is.null(internal_valid_task)) { test_data = internal_valid_task$data(cols = task$feature_names) test_target = internal_valid_task$data(cols = task$target_names) - test_data = xgboost::xgb.DMatrix(data = as_numeric_matrix(test_data), label = data.matrix(test_target)) + xgb_test_data = xgboost::xgb.DMatrix(data = as_numeric_matrix(test_data), label = data.matrix(test_target)) if (bm_is_feature) { - xgboost::setinfo(test_data, "base_margin", internal_valid_task$data(cols = bm)[[1L]]) + xgboost::setinfo(xgb_test_data, "base_margin", test_data[[bm]]) } - pv$watchlist = c(pv$watchlist, list(test = test_data)) + pv$watchlist = c(pv$watchlist, list(test = xgb_test_data)) } # set internal validation measure @@ -246,7 +246,7 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost", pv$maximize = !measure$minimize } - invoke(xgboost::xgb.train, data = data, .args = pv) + invoke(xgboost::xgb.train, data = xgb_data, .args = pv) }, #' Returns the `$best_iteration` when early stopping is activated. .predict = function(task) {