diff --git a/cpp/subprojects/boosting/include/mlrl/boosting/data/view_statistic_decomposable_bit.hpp b/cpp/subprojects/boosting/include/mlrl/boosting/data/view_statistic_decomposable_bit.hpp new file mode 100644 index 000000000..1dfac07e1 --- /dev/null +++ b/cpp/subprojects/boosting/include/mlrl/boosting/data/view_statistic_decomposable_bit.hpp @@ -0,0 +1,34 @@ +/* + * @author Michael Rapp (michael.rapp.ml@gmail.com) + */ +#pragma once + +#include "mlrl/common/data/view_matrix_bit.hpp" +#include "mlrl/common/data/view_matrix_composite.hpp" + +namespace boosting { + + /** + * Implements row-wise read and write access to the gradients and Hessians that have been calculated using a + * decomposable loss function and are stored in pre-allocated bit matrices. + */ + class BitDecomposableStatisticView + : public CompositeMatrix, AllocatedBitMatrix> { + public: + + /** + * @param numRows The number of rows in the view + * @param numCols The number of columns in the view + * @param numBits The number of bits per statistic in the view + */ + BitDecomposableStatisticView(uint32 numRows, uint32 numCols, uint32 numBits); + + /** + * @param other A reference to an object of type `BitDecomposableStatisticView` that should be copied + */ + BitDecomposableStatisticView(BitDecomposableStatisticView&& other); + + virtual ~BitDecomposableStatisticView() override {} + }; + +} diff --git a/cpp/subprojects/boosting/meson.build b/cpp/subprojects/boosting/meson.build index 6752a9ecc..4b878820e 100644 --- a/cpp/subprojects/boosting/meson.build +++ b/cpp/subprojects/boosting/meson.build @@ -10,6 +10,7 @@ source_files = [ 'src/mlrl/boosting/data/vector_statistic_decomposable_dense.cpp', 'src/mlrl/boosting/data/vector_statistic_decomposable_sparse.cpp', 'src/mlrl/boosting/data/vector_statistic_non_decomposable_dense.cpp', + 'src/mlrl/boosting/data/view_statistic_decomposable_bit.cpp', 'src/mlrl/boosting/data/view_statistic_non_decomposable_dense.cpp', 'src/mlrl/boosting/input/feature_binning_auto.cpp', 'src/mlrl/boosting/losses/loss_decomposable_logistic.cpp', diff --git a/cpp/subprojects/boosting/src/mlrl/boosting/data/view_statistic_decomposable_bit.cpp b/cpp/subprojects/boosting/src/mlrl/boosting/data/view_statistic_decomposable_bit.cpp new file mode 100644 index 000000000..008606bc6 --- /dev/null +++ b/cpp/subprojects/boosting/src/mlrl/boosting/data/view_statistic_decomposable_bit.cpp @@ -0,0 +1,15 @@ +#include "mlrl/boosting/data/view_statistic_decomposable_bit.hpp" + +#include "mlrl/boosting/util/math.hpp" + +namespace boosting { + + BitDecomposableStatisticView::BitDecomposableStatisticView(uint32 numRows, uint32 numCols, uint32 numBits) + : CompositeMatrix, AllocatedBitMatrix>( + AllocatedBitMatrix(numRows, numCols, numBits), + AllocatedBitMatrix(numRows, util::triangularNumber(numCols), numBits), numRows, numCols) {} + + BitDecomposableStatisticView::BitDecomposableStatisticView(BitDecomposableStatisticView&& other) + : CompositeMatrix, AllocatedBitMatrix>(std::move(other)) {} + +} diff --git a/cpp/subprojects/boosting/src/mlrl/boosting/statistics/quantization_stochastic.cpp b/cpp/subprojects/boosting/src/mlrl/boosting/statistics/quantization_stochastic.cpp index 975a551dc..fcd1d5a4d 100644 --- a/cpp/subprojects/boosting/src/mlrl/boosting/statistics/quantization_stochastic.cpp +++ b/cpp/subprojects/boosting/src/mlrl/boosting/statistics/quantization_stochastic.cpp @@ -1,6 +1,6 @@ #include "mlrl/boosting/statistics/quantization_stochastic.hpp" -#include "mlrl/common/data/matrix_bit_integer.hpp" +#include "mlrl/boosting/data/view_statistic_decomposable_bit.hpp" #include "mlrl/common/util/validation.hpp" namespace boosting { @@ -9,11 +9,12 @@ namespace boosting { class StochasticQuantization final : public IQuantization { private: - std::unique_ptr>> quantizationMatrixPtr_; + std::unique_ptr> quantizationMatrixPtr_; public: - StochasticQuantization(std::unique_ptr>> quantizationMatrixPtr) + StochasticQuantization( + std::unique_ptr> quantizationMatrixPtr) : quantizationMatrixPtr_(std::move(quantizationMatrixPtr)) {} void visitQuantizationMatrix( @@ -26,17 +27,17 @@ namespace boosting { }; template - class StochasticQuantizationMatrix final : public IQuantizationMatrix> { + class StochasticQuantizationMatrix final : public IQuantizationMatrix { private: const View& view_; - IntegerBitMatrix matrix_; + BitDecomposableStatisticView quantizedView_; public: StochasticQuantizationMatrix(const View& view, uint32 numBits) - : view_(view), matrix_(view.numRows, view.numCols, numBits, true) {} + : view_(view), quantizedView_(view.numRows, view.numCols, numBits) {} void quantize(const CompleteIndexVector& outputIndices) override { // TODO Implement @@ -46,28 +47,28 @@ namespace boosting { // TODO Implement } - const typename IQuantizationMatrix>::view_type& getView() const override { - return matrix_.getView(); + const typename IQuantizationMatrix::view_type& getView() const override { + return quantizedView_; } std::unique_ptr create( const CContiguousView>& statisticMatrix) const override { return std::make_unique>>>( std::make_unique>>>( - statisticMatrix, matrix_.getView().numBitsPerElement)); + statisticMatrix, quantizedView_.firstView.numBitsPerElement)); } std::unique_ptr create(const SparseSetView>& statisticMatrix) const override { return std::make_unique>>>( std::make_unique>>>( - statisticMatrix, matrix_.getView().numBitsPerElement)); + statisticMatrix, quantizedView_.firstView.numBitsPerElement)); } std::unique_ptr create( const DenseNonDecomposableStatisticView& statisticMatrix) const override { return std::make_unique>( std::make_unique>( - statisticMatrix, matrix_.getView().numBitsPerElement)); + statisticMatrix, quantizedView_.firstView.numBitsPerElement)); } }; diff --git a/cpp/subprojects/common/include/mlrl/common/data/matrix_bit_integer.hpp b/cpp/subprojects/common/include/mlrl/common/data/matrix_bit_integer.hpp deleted file mode 100644 index 2b540bf34..000000000 --- a/cpp/subprojects/common/include/mlrl/common/data/matrix_bit_integer.hpp +++ /dev/null @@ -1,23 +0,0 @@ -/* - * @author Michael Rapp (michael.rapp.ml@gmail.com) - */ -#pragma once - -#include "mlrl/common/data/view_matrix_bit.hpp" - -/** - * A two-dimensional matrix that stores integer values, each with a specific number of bits. - */ -class IntegerBitMatrix final : public MatrixDecorator> { - public: - - /** - * @param numRows The number of rows in the matrix - * @param numCols The number of columns in the matrix - * @param numBitsPerElement The number of bits per element in the matrix - * @param init True, if all elements in the matrix should be value-initialized, false otherwise - */ - IntegerBitMatrix(uint32 numRows, uint32 numCols, uint32 numBitsPerElements, bool init = false) - : MatrixDecorator>( - AllocatedBitMatrix(numRows, numCols, numBitsPerElements, init)) {} -};