Skip to content

Commit

Permalink
Add visitor for handling views of type BitDecomposableMatrixVisitor t…
Browse files Browse the repository at this point in the history
…o class IQuantization.
  • Loading branch information
michael-rapp committed Aug 19, 2024
1 parent b407be6 commit 0d152d1
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
*/
#pragma once

#include "mlrl/boosting/data/view_statistic_decomposable_bit.hpp"
#include "mlrl/boosting/data/view_statistic_non_decomposable_dense.hpp"
#include "mlrl/common/data/tuple.hpp"
#include "mlrl/common/data/view_matrix_c_contiguous.hpp"
Expand Down Expand Up @@ -43,6 +44,13 @@ namespace boosting {
typedef std::function<void(std::unique_ptr<IQuantizationMatrix<SparseSetView<Tuple<float64>>>>&)>
SparseDecomposableMatrixVisitor;

/**
* A visitor function for handling quantization matrices that are backed by a view of the type
* `BitDecomposableStatisticView`.
*/
typedef std::function<void(std::unique_ptr<IQuantizationMatrix<BitDecomposableStatisticView>>&)>
BitDecomposableMatrixVisitor;

/**
* A visitor function for handling quantization matrices that are backed by a view of the type
* `DenseNonDecomposableStatisticView`.
Expand All @@ -60,13 +68,17 @@ namespace boosting {
* @param sparseDecomposableMatrixVisitor An optional visitor function for handling quantization matrices
* that are backed by a view of the type
* `SparseSetView<Tuple<float64>>`
* @param bitDecomposableMatrixVisitor An optional visitor function for handling quantization matrices
* that are backed by a view of the type
* `BitDecomposableStatisticView`
* @param denseNonDecomposableMatrixVisitor An optional visitor function for handling quantization matrices
* that are backed by a view of the type
* `DenseNonDecomposableStatisticView`
*/
virtual void visitQuantizationMatrix(
std::optional<DenseDecomposableMatrixVisitor> denseDecomposableMatrixVisitor,
std::optional<SparseDecomposableMatrixVisitor> sparseDecomposableMatrixVisitor,
std::optional<BitDecomposableMatrixVisitor> bitDecomposableMatrixVisitor,
std::optional<DenseNonDecomposableMatrixVisitor> denseNonDecomposableMatrixVisitor) = 0;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ namespace boosting {
void visitQuantizationMatrix(
std::optional<IQuantization::DenseDecomposableMatrixVisitor> denseDecomposableMatrixVisitor,
std::optional<IQuantization::SparseDecomposableMatrixVisitor> sparseDecomposableMatrixVisitor,
std::optional<BitDecomposableMatrixVisitor> bitDecomposableMatrixVisitor,
std::optional<IQuantization::DenseNonDecomposableMatrixVisitor> denseNonDecomposableMatrixVisitor)
override {
if (denseDecomposableMatrixVisitor) {
Expand All @@ -38,6 +39,7 @@ namespace boosting {
void visitQuantizationMatrix(
std::optional<IQuantization::DenseDecomposableMatrixVisitor> denseDecomposableMatrixVisitor,
std::optional<IQuantization::SparseDecomposableMatrixVisitor> sparseDecomposableMatrixVisitor,
std::optional<BitDecomposableMatrixVisitor> bitDecomposableMatrixVisitor,
std::optional<IQuantization::DenseNonDecomposableMatrixVisitor> denseNonDecomposableMatrixVisitor)
override {
if (sparseDecomposableMatrixVisitor) {
Expand All @@ -60,6 +62,7 @@ namespace boosting {
void visitQuantizationMatrix(
std::optional<IQuantization::DenseDecomposableMatrixVisitor> denseDecomposableMatrixVisitor,
std::optional<IQuantization::SparseDecomposableMatrixVisitor> sparseDecomposableMatrixVisitor,
std::optional<BitDecomposableMatrixVisitor> bitDecomposableMatrixVisitor,
std::optional<IQuantization::DenseNonDecomposableMatrixVisitor> denseNonDecomposableMatrixVisitor)
override {
if (denseNonDecomposableMatrixVisitor) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@ namespace boosting {
void visitQuantizationMatrix(
std::optional<IQuantization::DenseDecomposableMatrixVisitor> denseDecomposableMatrixVisitor,
std::optional<IQuantization::SparseDecomposableMatrixVisitor> sparseDecomposableMatrixVisitor,
std::optional<BitDecomposableMatrixVisitor> bitDecomposableMatrixVisitor,
std::optional<IQuantization::DenseNonDecomposableMatrixVisitor> denseNonDecomposableMatrixVisitor)
override {
// TODO Implement
if (bitDecomposableMatrixVisitor) {
(*bitDecomposableMatrixVisitor)(quantizationMatrixPtr_);
}
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ namespace boosting {
std::move(quantizationMatrixPtr), std::move(lossPtr), std::move(evaluationMeasurePtr),
ruleEvaluationFactory, outputMatrix, std::move(statisticMatrixPtr), std::move(scoreMatrixPtr));
};
quantizationPtr->visitQuantizationMatrix(denseDecomposableMatrixVisitor, {}, {});
quantizationPtr->visitQuantizationMatrix(denseDecomposableMatrixVisitor, {}, {}, {});
return statisticsPtr;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ namespace boosting {
std::move(quantizationMatrixPtr), std::move(lossPtr), std::move(evaluationMeasurePtr),
ruleEvaluationFactory, outputMatrix, std::move(statisticMatrixPtr), std::move(scoreMatrixPtr));
};
quantizationPtr->visitQuantizationMatrix({}, sparseDecomposableMatrixVisitor, {});
quantizationPtr->visitQuantizationMatrix({}, sparseDecomposableMatrixVisitor, {}, {});
return statisticsPtr;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ namespace boosting {
std::move(this->evaluationMeasurePtr_), ruleEvaluationFactory, this->outputMatrix_,
std::move(decomposableStatisticMatrixPtr), std::move(this->scoreMatrixPtr_));
};
quantizationPtr->visitQuantizationMatrix(denseDecomposableMatrixVisitor, {}, {});
quantizationPtr->visitQuantizationMatrix(denseDecomposableMatrixVisitor, {}, {}, {});
return statisticsPtr;
}
};
Expand Down Expand Up @@ -197,7 +197,7 @@ namespace boosting {
std::move(quantizationMatrixPtr), std::move(lossPtr), std::move(evaluationMeasurePtr),
ruleEvaluationFactory, outputMatrix, std::move(statisticMatrixPtr), std::move(scoreMatrixPtr));
};
quantizationPtr->visitQuantizationMatrix({}, {}, denseNonDecomposableMatrixVisitor);
quantizationPtr->visitQuantizationMatrix({}, {}, {}, denseNonDecomposableMatrixVisitor);
return statisticsPtr;
}

Expand Down

0 comments on commit 0d152d1

Please sign in to comment.