From 9d7d0482ffd2b6b244716788a1ac361a1185f50e Mon Sep 17 00:00:00 2001 From: Michael Rapp Date: Mon, 19 Aug 2024 18:25:26 +0200 Subject: [PATCH] Add template argument StatisticVector to class DenseDecomposableStatistics. --- .../boosting/statistics/statistics_decomposable_dense.hpp | 8 +++++--- .../statistics/statistics_provider_decomposable_dense.cpp | 5 +++-- .../statistics_provider_non_decomposable_dense.cpp | 3 ++- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_decomposable_dense.hpp b/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_decomposable_dense.hpp index b164d6eee1..0a178822f0 100644 --- a/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_decomposable_dense.hpp +++ b/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_decomposable_dense.hpp @@ -53,11 +53,13 @@ namespace boosting { * @tparam OutputMatrix The type of the matrix that provides access to the ground truth of the training * examples * @tparam QuantizationMatrix The type of the matrix that provides access to quantized gradients and Hessians + * @tparam StatisticVector The type of the vectors that are used to store gradients and Hessians * @tparam EvaluationMeasure The type of the evaluation that should be used to access the quality of predictions */ - template + template class DenseDecomposableStatistics final - : public AbstractDecomposableStatistics, Loss, EvaluationMeasure, IDecomposableRuleEvaluationFactory> { public: @@ -87,7 +89,7 @@ namespace boosting { const OutputMatrix& outputMatrix, std::unique_ptr statisticMatrixPtr, std::unique_ptr> scoreMatrixPtr) - : AbstractDecomposableStatistics, Loss, EvaluationMeasure, IDecomposableRuleEvaluationFactory>( std::move(quantizationMatrixPtr), std::move(lossPtr), std::move(evaluationMeasurePtr), diff --git a/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_decomposable_dense.cpp b/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_decomposable_dense.cpp index 01703c4843..b6b67f6924 100644 --- a/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_decomposable_dense.cpp +++ b/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_decomposable_dense.cpp @@ -36,8 +36,9 @@ namespace boosting { std::unique_ptr> statisticsPtr; auto denseDecomposableMatrixVisitor = [&](std::unique_ptr>>>& quantizationMatrixPtr) { - statisticsPtr = std::make_unique>>, EvaluationMeasure>>( + statisticsPtr = std::make_unique< + DenseDecomposableStatistics>>, + DenseDecomposableStatisticVector, EvaluationMeasure>>( std::move(quantizationMatrixPtr), std::move(lossPtr), std::move(evaluationMeasurePtr), ruleEvaluationFactory, outputMatrix, std::move(statisticMatrixPtr), std::move(scoreMatrixPtr)); }; diff --git a/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_non_decomposable_dense.cpp b/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_non_decomposable_dense.cpp index 14acfac5d0..fae99bb6cd 100644 --- a/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_non_decomposable_dense.cpp +++ b/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_non_decomposable_dense.cpp @@ -150,7 +150,8 @@ namespace boosting { auto denseDecomposableMatrixVisitor = [&](std::unique_ptr>>>& quantizationMatrixPtr) { statisticsPtr = std::make_unique>>, EvaluationMeasure>>( + Loss, OutputMatrix, IQuantizationMatrix>>, + DenseDecomposableStatisticVector, EvaluationMeasure>>( std::move(quantizationMatrixPtr), std::move(this->lossPtr_), std::move(this->evaluationMeasurePtr_), ruleEvaluationFactory, this->outputMatrix_, std::move(decomposableStatisticMatrixPtr), std::move(this->scoreMatrixPtr_));