Skip to content

Commit

Permalink
Replace class IntegerBitMatrix with new class BitDecomposableStatisti…
Browse files Browse the repository at this point in the history
…cView.
  • Loading branch information
michael-rapp committed Oct 2, 2024
1 parent 7e97408 commit cb7d5d8
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 34 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* @author Michael Rapp ([email protected])
*/
#pragma once

#include "mlrl/common/data/view_matrix_bit.hpp"
#include "mlrl/common/data/view_matrix_composite.hpp"

namespace boosting {

/**
* Implements row-wise read and write access to the gradients and Hessians that have been calculated using a
* decomposable loss function and are stored in pre-allocated bit matrices.
*/
class BitDecomposableStatisticView
: public CompositeMatrix<AllocatedBitMatrix<uint32>, AllocatedBitMatrix<uint32>> {
public:

/**
* @param numRows The number of rows in the view
* @param numCols The number of columns in the view
* @param numBits The number of bits per statistic in the view
*/
BitDecomposableStatisticView(uint32 numRows, uint32 numCols, uint32 numBits);

/**
* @param other A reference to an object of type `BitDecomposableStatisticView` that should be copied
*/
BitDecomposableStatisticView(BitDecomposableStatisticView&& other);

virtual ~BitDecomposableStatisticView() override {}
};

}
1 change: 1 addition & 0 deletions cpp/subprojects/boosting/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ source_files = [
'src/mlrl/boosting/data/vector_statistic_decomposable_dense.cpp',
'src/mlrl/boosting/data/vector_statistic_decomposable_sparse.cpp',
'src/mlrl/boosting/data/vector_statistic_non_decomposable_dense.cpp',
'src/mlrl/boosting/data/view_statistic_decomposable_bit.cpp',
'src/mlrl/boosting/data/view_statistic_non_decomposable_dense.cpp',
'src/mlrl/boosting/input/feature_binning_auto.cpp',
'src/mlrl/boosting/losses/loss_decomposable_logistic.cpp',
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#include "mlrl/boosting/data/view_statistic_decomposable_bit.hpp"

#include "mlrl/boosting/util/math.hpp"

namespace boosting {

BitDecomposableStatisticView::BitDecomposableStatisticView(uint32 numRows, uint32 numCols, uint32 numBits)
: CompositeMatrix<AllocatedBitMatrix<uint32>, AllocatedBitMatrix<uint32>>(
AllocatedBitMatrix<uint32>(numRows, numCols, numBits),
AllocatedBitMatrix<uint32>(numRows, util::triangularNumber(numCols), numBits), numRows, numCols) {}

BitDecomposableStatisticView::BitDecomposableStatisticView(BitDecomposableStatisticView&& other)
: CompositeMatrix<AllocatedBitMatrix<uint32>, AllocatedBitMatrix<uint32>>(std::move(other)) {}

}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "mlrl/boosting/statistics/quantization_stochastic.hpp"

#include "mlrl/common/data/matrix_bit_integer.hpp"
#include "mlrl/boosting/data/view_statistic_decomposable_bit.hpp"
#include "mlrl/common/util/validation.hpp"

namespace boosting {
Expand All @@ -9,11 +9,12 @@ namespace boosting {
class StochasticQuantization final : public IQuantization {
private:

std::unique_ptr<IQuantizationMatrix<BitMatrix<uint32>>> quantizationMatrixPtr_;
std::unique_ptr<IQuantizationMatrix<BitDecomposableStatisticView>> quantizationMatrixPtr_;

public:

StochasticQuantization(std::unique_ptr<IQuantizationMatrix<BitMatrix<uint32>>> quantizationMatrixPtr)
StochasticQuantization(
std::unique_ptr<IQuantizationMatrix<BitDecomposableStatisticView>> quantizationMatrixPtr)
: quantizationMatrixPtr_(std::move(quantizationMatrixPtr)) {}

void visitQuantizationMatrix(
Expand All @@ -26,17 +27,17 @@ namespace boosting {
};

template<typename View>
class StochasticQuantizationMatrix final : public IQuantizationMatrix<BitMatrix<uint32>> {
class StochasticQuantizationMatrix final : public IQuantizationMatrix<BitDecomposableStatisticView> {
private:

const View& view_;

IntegerBitMatrix matrix_;
BitDecomposableStatisticView quantizedView_;

public:

StochasticQuantizationMatrix(const View& view, uint32 numBits)
: view_(view), matrix_(view.numRows, view.numCols, numBits, true) {}
: view_(view), quantizedView_(view.numRows, view.numCols, numBits) {}

void quantize(const CompleteIndexVector& outputIndices) override {
// TODO Implement
Expand All @@ -46,28 +47,28 @@ namespace boosting {
// TODO Implement
}

const typename IQuantizationMatrix<BitMatrix<uint32>>::view_type& getView() const override {
return matrix_.getView();
const typename IQuantizationMatrix<BitDecomposableStatisticView>::view_type& getView() const override {
return quantizedView_;
}

std::unique_ptr<IQuantization> create(
const CContiguousView<Tuple<float64>>& statisticMatrix) const override {
return std::make_unique<StochasticQuantization<CContiguousView<Tuple<float64>>>>(
std::make_unique<StochasticQuantizationMatrix<CContiguousView<Tuple<float64>>>>(
statisticMatrix, matrix_.getView().numBitsPerElement));
statisticMatrix, quantizedView_.firstView.numBitsPerElement));
}

std::unique_ptr<IQuantization> create(const SparseSetView<Tuple<float64>>& statisticMatrix) const override {
return std::make_unique<StochasticQuantization<SparseSetView<Tuple<float64>>>>(
std::make_unique<StochasticQuantizationMatrix<SparseSetView<Tuple<float64>>>>(
statisticMatrix, matrix_.getView().numBitsPerElement));
statisticMatrix, quantizedView_.firstView.numBitsPerElement));
}

std::unique_ptr<IQuantization> create(
const DenseNonDecomposableStatisticView& statisticMatrix) const override {
return std::make_unique<StochasticQuantization<DenseNonDecomposableStatisticView>>(
std::make_unique<StochasticQuantizationMatrix<DenseNonDecomposableStatisticView>>(
statisticMatrix, matrix_.getView().numBitsPerElement));
statisticMatrix, quantizedView_.firstView.numBitsPerElement));
}
};

Expand Down

This file was deleted.

0 comments on commit cb7d5d8

Please sign in to comment.