Skip to content

Commit

Permalink
Add visitor for handling views of type BitDecomposableMatrixVisitor t…
Browse files Browse the repository at this point in the history
…o class IQuantization.
  • Loading branch information
michael-rapp committed Oct 2, 2024
1 parent a1c2c5c commit 457e7db
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
*/
#pragma once

#include "mlrl/boosting/data/view_statistic_decomposable_bit.hpp"
#include "mlrl/boosting/data/view_statistic_non_decomposable_dense.hpp"
#include "mlrl/common/data/tuple.hpp"
#include "mlrl/common/data/view_matrix_c_contiguous.hpp"
Expand Down Expand Up @@ -43,6 +44,13 @@ namespace boosting {
typedef std::function<void(std::unique_ptr<IQuantizationMatrix<SparseSetView<Tuple<float64>>>>&)>
SparseDecomposableMatrixVisitor;

/**
* A visitor function for handling quantization matrices that are backed by a view of the type
* `BitDecomposableStatisticView`.
*/
typedef std::function<void(std::unique_ptr<IQuantizationMatrix<BitDecomposableStatisticView>>&)>
BitDecomposableMatrixVisitor;

/**
* A visitor function for handling quantization matrices that are backed by a view of the type
* `DenseNonDecomposableStatisticView`.
Expand All @@ -57,6 +65,9 @@ namespace boosting {
* @param denseDecomposableMatrixVisitor An optional visitor function for handling quantization matrices
* that are backed by a view of the type
* `CContiguousView<Tuple<float64>>`
* @param bitDecomposableMatrixVisitor An optional visitor function for handling quantization matrices
* that are backed by a view of the type
* `BitDecomposableStatisticView`
* @param sparseDecomposableMatrixVisitor An optional visitor function for handling quantization matrices
* that are backed by a view of the type
* `SparseSetView<Tuple<float64>>`
Expand All @@ -66,6 +77,7 @@ namespace boosting {
*/
virtual void visitQuantizationMatrix(
std::optional<DenseDecomposableMatrixVisitor> denseDecomposableMatrixVisitor,
std::optional<BitDecomposableMatrixVisitor> bitDecomposableMatrixVisitor,
std::optional<SparseDecomposableMatrixVisitor> sparseDecomposableMatrixVisitor,
std::optional<DenseNonDecomposableMatrixVisitor> denseNonDecomposableMatrixVisitor) = 0;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ namespace boosting {

void visitQuantizationMatrix(
std::optional<IQuantization::DenseDecomposableMatrixVisitor> denseDecomposableMatrixVisitor,
std::optional<BitDecomposableMatrixVisitor> bitDecomposableMatrixVisitor,
std::optional<IQuantization::SparseDecomposableMatrixVisitor> sparseDecomposableMatrixVisitor,
std::optional<IQuantization::DenseNonDecomposableMatrixVisitor> denseNonDecomposableMatrixVisitor)
override {
Expand All @@ -37,6 +38,7 @@ namespace boosting {

void visitQuantizationMatrix(
std::optional<IQuantization::DenseDecomposableMatrixVisitor> denseDecomposableMatrixVisitor,
std::optional<BitDecomposableMatrixVisitor> bitDecomposableMatrixVisitor,
std::optional<IQuantization::SparseDecomposableMatrixVisitor> sparseDecomposableMatrixVisitor,
std::optional<IQuantization::DenseNonDecomposableMatrixVisitor> denseNonDecomposableMatrixVisitor)
override {
Expand All @@ -59,6 +61,7 @@ namespace boosting {

void visitQuantizationMatrix(
std::optional<IQuantization::DenseDecomposableMatrixVisitor> denseDecomposableMatrixVisitor,
std::optional<BitDecomposableMatrixVisitor> bitDecomposableMatrixVisitor,
std::optional<IQuantization::SparseDecomposableMatrixVisitor> sparseDecomposableMatrixVisitor,
std::optional<IQuantization::DenseNonDecomposableMatrixVisitor> denseNonDecomposableMatrixVisitor)
override {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@ namespace boosting {

void visitQuantizationMatrix(
std::optional<IQuantization::DenseDecomposableMatrixVisitor> denseDecomposableMatrixVisitor,
std::optional<BitDecomposableMatrixVisitor> bitDecomposableMatrixVisitor,
std::optional<IQuantization::SparseDecomposableMatrixVisitor> sparseDecomposableMatrixVisitor,
std::optional<IQuantization::DenseNonDecomposableMatrixVisitor> denseNonDecomposableMatrixVisitor)
override {
// TODO Implement
if (bitDecomposableMatrixVisitor) {
(*bitDecomposableMatrixVisitor)(quantizationMatrixPtr_);
}
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@ namespace boosting {
std::move(quantizationMatrixPtr), std::move(lossPtr), std::move(evaluationMeasurePtr),
ruleEvaluationFactory, outputMatrix, std::move(statisticMatrixPtr), std::move(scoreMatrixPtr));
};
quantizationPtr->visitQuantizationMatrix(denseDecomposableMatrixVisitor, {}, {});
auto bitDecomposableMatrixVisitor =
[&](std::unique_ptr<IQuantizationMatrix<BitDecomposableStatisticView>>& quantizationMatrixPtr) {
// TODO Implement
throw std::runtime_error("not implemented");
};
quantizationPtr->visitQuantizationMatrix(denseDecomposableMatrixVisitor, bitDecomposableMatrixVisitor, {}, {});
return statisticsPtr;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ namespace boosting {
std::move(quantizationMatrixPtr), std::move(lossPtr), std::move(evaluationMeasurePtr),
ruleEvaluationFactory, outputMatrix, std::move(statisticMatrixPtr), std::move(scoreMatrixPtr));
};
quantizationPtr->visitQuantizationMatrix({}, sparseDecomposableMatrixVisitor, {});
quantizationPtr->visitQuantizationMatrix({}, {}, sparseDecomposableMatrixVisitor, {});
return statisticsPtr;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,13 @@ namespace boosting {
std::move(this->evaluationMeasurePtr_), ruleEvaluationFactory, this->outputMatrix_,
std::move(decomposableStatisticMatrixPtr), std::move(this->scoreMatrixPtr_));
};
quantizationPtr->visitQuantizationMatrix(denseDecomposableMatrixVisitor, {}, {});
auto bitDecomposableMatrixVisitor =
[&](std::unique_ptr<IQuantizationMatrix<BitDecomposableStatisticView>>& quantizationMatrixPtr) {
// TODO Implement
throw std::runtime_error("not implemented");
};
quantizationPtr->visitQuantizationMatrix(denseDecomposableMatrixVisitor, bitDecomposableMatrixVisitor,
{}, {});
return statisticsPtr;
}
};
Expand Down Expand Up @@ -198,7 +204,7 @@ namespace boosting {
std::move(quantizationMatrixPtr), std::move(lossPtr), std::move(evaluationMeasurePtr),
ruleEvaluationFactory, outputMatrix, std::move(statisticMatrixPtr), std::move(scoreMatrixPtr));
};
quantizationPtr->visitQuantizationMatrix({}, {}, denseNonDecomposableMatrixVisitor);
quantizationPtr->visitQuantizationMatrix({}, {}, {}, denseNonDecomposableMatrixVisitor);
return statisticsPtr;
}

Expand Down

0 comments on commit 457e7db

Please sign in to comment.