diff --git a/cpp/subprojects/boosting/include/mlrl/boosting/statistics/quantization.hpp b/cpp/subprojects/boosting/include/mlrl/boosting/statistics/quantization.hpp index 41323a3c2..a74154b82 100644 --- a/cpp/subprojects/boosting/include/mlrl/boosting/statistics/quantization.hpp +++ b/cpp/subprojects/boosting/include/mlrl/boosting/statistics/quantization.hpp @@ -3,6 +3,7 @@ */ #pragma once +#include "mlrl/boosting/data/view_statistic_decomposable_bit.hpp" #include "mlrl/boosting/data/view_statistic_non_decomposable_dense.hpp" #include "mlrl/common/data/tuple.hpp" #include "mlrl/common/data/view_matrix_c_contiguous.hpp" @@ -43,6 +44,13 @@ namespace boosting { typedef std::function>>>&)> SparseDecomposableMatrixVisitor; + /** + * A visitor function for handling quantization matrices that are backed by a view of the type + * `BitDecomposableStatisticView`. + */ + typedef std::function>&)> + BitDecomposableMatrixVisitor; + /** * A visitor function for handling quantization matrices that are backed by a view of the type * `DenseNonDecomposableStatisticView`. @@ -57,6 +65,9 @@ namespace boosting { * @param denseDecomposableMatrixVisitor An optional visitor function for handling quantization matrices * that are backed by a view of the type * `CContiguousView>` + * @param bitDecomposableMatrixVisitor An optional visitor function for handling quantization matrices + * that are backed by a view of the type + * `BitDecomposableStatisticView` * @param sparseDecomposableMatrixVisitor An optional visitor function for handling quantization matrices * that are backed by a view of the type * `SparseSetView>` @@ -66,6 +77,7 @@ namespace boosting { */ virtual void visitQuantizationMatrix( std::optional denseDecomposableMatrixVisitor, + std::optional bitDecomposableMatrixVisitor, std::optional sparseDecomposableMatrixVisitor, std::optional denseNonDecomposableMatrixVisitor) = 0; }; diff --git a/cpp/subprojects/boosting/src/mlrl/boosting/statistics/quantization_no.cpp b/cpp/subprojects/boosting/src/mlrl/boosting/statistics/quantization_no.cpp index f7bf7eb4c..61958c6e2 100644 --- a/cpp/subprojects/boosting/src/mlrl/boosting/statistics/quantization_no.cpp +++ b/cpp/subprojects/boosting/src/mlrl/boosting/statistics/quantization_no.cpp @@ -15,6 +15,7 @@ namespace boosting { void visitQuantizationMatrix( std::optional denseDecomposableMatrixVisitor, + std::optional bitDecomposableMatrixVisitor, std::optional sparseDecomposableMatrixVisitor, std::optional denseNonDecomposableMatrixVisitor) override { @@ -37,6 +38,7 @@ namespace boosting { void visitQuantizationMatrix( std::optional denseDecomposableMatrixVisitor, + std::optional bitDecomposableMatrixVisitor, std::optional sparseDecomposableMatrixVisitor, std::optional denseNonDecomposableMatrixVisitor) override { @@ -59,6 +61,7 @@ namespace boosting { void visitQuantizationMatrix( std::optional denseDecomposableMatrixVisitor, + std::optional bitDecomposableMatrixVisitor, std::optional sparseDecomposableMatrixVisitor, std::optional denseNonDecomposableMatrixVisitor) override { diff --git a/cpp/subprojects/boosting/src/mlrl/boosting/statistics/quantization_stochastic.cpp b/cpp/subprojects/boosting/src/mlrl/boosting/statistics/quantization_stochastic.cpp index fcd1d5a4d..7c0799448 100644 --- a/cpp/subprojects/boosting/src/mlrl/boosting/statistics/quantization_stochastic.cpp +++ b/cpp/subprojects/boosting/src/mlrl/boosting/statistics/quantization_stochastic.cpp @@ -19,10 +19,13 @@ namespace boosting { void visitQuantizationMatrix( std::optional denseDecomposableMatrixVisitor, + std::optional bitDecomposableMatrixVisitor, std::optional sparseDecomposableMatrixVisitor, std::optional denseNonDecomposableMatrixVisitor) override { - // TODO Implement + if (bitDecomposableMatrixVisitor) { + (*bitDecomposableMatrixVisitor)(quantizationMatrixPtr_); + } } }; diff --git a/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_decomposable_dense.cpp b/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_decomposable_dense.cpp index 047d455f1..01703c484 100644 --- a/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_decomposable_dense.cpp +++ b/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_decomposable_dense.cpp @@ -41,7 +41,12 @@ namespace boosting { std::move(quantizationMatrixPtr), std::move(lossPtr), std::move(evaluationMeasurePtr), ruleEvaluationFactory, outputMatrix, std::move(statisticMatrixPtr), std::move(scoreMatrixPtr)); }; - quantizationPtr->visitQuantizationMatrix(denseDecomposableMatrixVisitor, {}, {}); + auto bitDecomposableMatrixVisitor = + [&](std::unique_ptr>& quantizationMatrixPtr) { + // TODO Implement + throw std::runtime_error("not implemented"); + }; + quantizationPtr->visitQuantizationMatrix(denseDecomposableMatrixVisitor, bitDecomposableMatrixVisitor, {}, {}); return statisticsPtr; } diff --git a/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_decomposable_sparse.cpp b/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_decomposable_sparse.cpp index 273e08c7a..491e8ba44 100644 --- a/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_decomposable_sparse.cpp +++ b/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_decomposable_sparse.cpp @@ -124,7 +124,7 @@ namespace boosting { std::move(quantizationMatrixPtr), std::move(lossPtr), std::move(evaluationMeasurePtr), ruleEvaluationFactory, outputMatrix, std::move(statisticMatrixPtr), std::move(scoreMatrixPtr)); }; - quantizationPtr->visitQuantizationMatrix({}, sparseDecomposableMatrixVisitor, {}); + quantizationPtr->visitQuantizationMatrix({}, {}, sparseDecomposableMatrixVisitor, {}); return statisticsPtr; } diff --git a/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_non_decomposable_dense.cpp b/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_non_decomposable_dense.cpp index 1ad2e3eec..14acfac5d 100644 --- a/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_non_decomposable_dense.cpp +++ b/cpp/subprojects/boosting/src/mlrl/boosting/statistics/statistics_provider_non_decomposable_dense.cpp @@ -155,7 +155,13 @@ namespace boosting { std::move(this->evaluationMeasurePtr_), ruleEvaluationFactory, this->outputMatrix_, std::move(decomposableStatisticMatrixPtr), std::move(this->scoreMatrixPtr_)); }; - quantizationPtr->visitQuantizationMatrix(denseDecomposableMatrixVisitor, {}, {}); + auto bitDecomposableMatrixVisitor = + [&](std::unique_ptr>& quantizationMatrixPtr) { + // TODO Implement + throw std::runtime_error("not implemented"); + }; + quantizationPtr->visitQuantizationMatrix(denseDecomposableMatrixVisitor, bitDecomposableMatrixVisitor, + {}, {}); return statisticsPtr; } }; @@ -198,7 +204,7 @@ namespace boosting { std::move(quantizationMatrixPtr), std::move(lossPtr), std::move(evaluationMeasurePtr), ruleEvaluationFactory, outputMatrix, std::move(statisticMatrixPtr), std::move(scoreMatrixPtr)); }; - quantizationPtr->visitQuantizationMatrix({}, {}, denseNonDecomposableMatrixVisitor); + quantizationPtr->visitQuantizationMatrix({}, {}, {}, denseNonDecomposableMatrixVisitor); return statisticsPtr; }