diff --git a/cpp/subprojects/boosting/include/mlrl/boosting/statistics/quantization.hpp b/cpp/subprojects/boosting/include/mlrl/boosting/statistics/quantization.hpp index 41323a3c2..301b8d157 100644 --- a/cpp/subprojects/boosting/include/mlrl/boosting/statistics/quantization.hpp +++ b/cpp/subprojects/boosting/include/mlrl/boosting/statistics/quantization.hpp @@ -3,6 +3,7 @@ */ #pragma once +#include "mlrl/boosting/data/view_statistic_decomposable_bit.hpp" #include "mlrl/boosting/data/view_statistic_non_decomposable_dense.hpp" #include "mlrl/common/data/tuple.hpp" #include "mlrl/common/data/view_matrix_c_contiguous.hpp" @@ -43,6 +44,13 @@ namespace boosting { typedef std::function>>>&)> SparseDecomposableMatrixVisitor; + /** + * A visitor function for handling quantization matrices that are backed by a view of the type + * `BitDecomposableStatisticView`. + */ + typedef std::function>&)> + BitDecomposableMatrixVisitor; + /** * A visitor function for handling quantization matrices that are backed by a view of the type * `DenseNonDecomposableStatisticView`. @@ -60,6 +68,9 @@ namespace boosting { * @param sparseDecomposableMatrixVisitor An optional visitor function for handling quantization matrices * that are backed by a view of the type * `SparseSetView>` + * @param bitDecomposableMatrixVisitor An optional visitor function for handling quantization matrices + * that are backed by a view of the type + * `BitDecomposableStatisticView` * @param denseNonDecomposableMatrixVisitor An optional visitor function for handling quantization matrices * that are backed by a view of the type * `DenseNonDecomposableStatisticView` @@ -67,6 +78,7 @@ namespace boosting { virtual void visitQuantizationMatrix( std::optional denseDecomposableMatrixVisitor, std::optional sparseDecomposableMatrixVisitor, + std::optional bitDecomposableMatrixVisitor, std::optional denseNonDecomposableMatrixVisitor) = 0; }; diff --git a/cpp/subprojects/boosting/src/mlrl/boosting/statistics/quantization_no.cpp b/cpp/subprojects/boosting/src/mlrl/boosting/statistics/quantization_no.cpp index f7bf7eb4c..63120a899 100644 --- a/cpp/subprojects/boosting/src/mlrl/boosting/statistics/quantization_no.cpp +++ b/cpp/subprojects/boosting/src/mlrl/boosting/statistics/quantization_no.cpp @@ -16,6 +16,7 @@ namespace boosting { void visitQuantizationMatrix( std::optional denseDecomposableMatrixVisitor, std::optional sparseDecomposableMatrixVisitor, + std::optional bitDecomposableMatrixVisitor, std::optional denseNonDecomposableMatrixVisitor) override { if (denseDecomposableMatrixVisitor) { @@ -38,6 +39,7 @@ namespace boosting { void visitQuantizationMatrix( std::optional denseDecomposableMatrixVisitor, std::optional sparseDecomposableMatrixVisitor, + std::optional bitDecomposableMatrixVisitor, std::optional denseNonDecomposableMatrixVisitor) override { if (sparseDecomposableMatrixVisitor) { @@ -60,6 +62,7 @@ namespace boosting { void visitQuantizationMatrix( std::optional denseDecomposableMatrixVisitor, std::optional sparseDecomposableMatrixVisitor, + std::optional bitDecomposableMatrixVisitor, std::optional denseNonDecomposableMatrixVisitor) override { if (denseNonDecomposableMatrixVisitor) { diff --git a/cpp/subprojects/boosting/src/mlrl/boosting/statistics/quantization_stochastic.cpp b/cpp/subprojects/boosting/src/mlrl/boosting/statistics/quantization_stochastic.cpp index fcd1d5a4d..f4e4c72da 100644 --- a/cpp/subprojects/boosting/src/mlrl/boosting/statistics/quantization_stochastic.cpp +++ b/cpp/subprojects/boosting/src/mlrl/boosting/statistics/quantization_stochastic.cpp @@ -20,9 +20,12 @@ namespace boosting { void visitQuantizationMatrix( std::optional denseDecomposableMatrixVisitor, std::optional sparseDecomposableMatrixVisitor, + std::optional bitDecomposableMatrixVisitor, std::optional denseNonDecomposableMatrixVisitor) override { - // TODO Implement + if (bitDecomposableMatrixVisitor) { + (*bitDecomposableMatrixVisitor)(quantizationMatrixPtr_); + } } }; diff --git a/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_decomposable_dense.cpp b/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_decomposable_dense.cpp index 66e7f453b..bde03127a 100644 --- a/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_decomposable_dense.cpp +++ b/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_decomposable_dense.cpp @@ -40,7 +40,7 @@ namespace boosting { std::move(quantizationMatrixPtr), std::move(lossPtr), std::move(evaluationMeasurePtr), ruleEvaluationFactory, outputMatrix, std::move(statisticMatrixPtr), std::move(scoreMatrixPtr)); }; - quantizationPtr->visitQuantizationMatrix(denseDecomposableMatrixVisitor, {}, {}); + quantizationPtr->visitQuantizationMatrix(denseDecomposableMatrixVisitor, {}, {}, {}); return statisticsPtr; } diff --git a/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_decomposable_sparse.cpp b/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_decomposable_sparse.cpp index 0ae0ea9b8..d67fff461 100644 --- a/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_decomposable_sparse.cpp +++ b/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_decomposable_sparse.cpp @@ -123,7 +123,7 @@ namespace boosting { std::move(quantizationMatrixPtr), std::move(lossPtr), std::move(evaluationMeasurePtr), ruleEvaluationFactory, outputMatrix, std::move(statisticMatrixPtr), std::move(scoreMatrixPtr)); }; - quantizationPtr->visitQuantizationMatrix({}, sparseDecomposableMatrixVisitor, {}); + quantizationPtr->visitQuantizationMatrix({}, sparseDecomposableMatrixVisitor, {}, {}); return statisticsPtr; } diff --git a/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_non_decomposable_dense.cpp b/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_non_decomposable_dense.cpp index 086003cb0..fc22d458d 100644 --- a/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_non_decomposable_dense.cpp +++ b/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_non_decomposable_dense.cpp @@ -154,7 +154,7 @@ namespace boosting { std::move(this->evaluationMeasurePtr_), ruleEvaluationFactory, this->outputMatrix_, std::move(decomposableStatisticMatrixPtr), std::move(this->scoreMatrixPtr_)); }; - quantizationPtr->visitQuantizationMatrix(denseDecomposableMatrixVisitor, {}, {}); + quantizationPtr->visitQuantizationMatrix(denseDecomposableMatrixVisitor, {}, {}, {}); return statisticsPtr; } }; @@ -197,7 +197,7 @@ namespace boosting { std::move(quantizationMatrixPtr), std::move(lossPtr), std::move(evaluationMeasurePtr), ruleEvaluationFactory, outputMatrix, std::move(statisticMatrixPtr), std::move(scoreMatrixPtr)); }; - quantizationPtr->visitQuantizationMatrix({}, {}, denseNonDecomposableMatrixVisitor); + quantizationPtr->visitQuantizationMatrix({}, {}, {}, denseNonDecomposableMatrixVisitor); return statisticsPtr; }