Skip to content

Commit

Permalink
Add template argument StatisticVector to class DenseDecomposableStati…
Browse files Browse the repository at this point in the history
…stics.
  • Loading branch information
michael-rapp committed Sep 13, 2024
1 parent 23c1241 commit d2bad1b
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,13 @@ namespace boosting {
* @tparam OutputMatrix The type of the matrix that provides access to the ground truth of the training
* examples
* @tparam QuantizationMatrix The type of the matrix that provides access to quantized gradients and Hessians
* @tparam StatisticVector The type of the vectors that are used to store gradients and Hessians
* @tparam EvaluationMeasure The type of the evaluation that should be used to access the quality of predictions
*/
template<typename Loss, typename OutputMatrix, typename QuantizationMatrix, typename EvaluationMeasure>
template<typename Loss, typename OutputMatrix, typename QuantizationMatrix, typename StatisticVector,
typename EvaluationMeasure>
class DenseDecomposableStatistics final
: public AbstractDecomposableStatistics<OutputMatrix, QuantizationMatrix, DenseDecomposableStatisticVector,
: public AbstractDecomposableStatistics<OutputMatrix, QuantizationMatrix, StatisticVector,
DenseDecomposableStatisticMatrix, NumericCContiguousMatrix<float64>,
Loss, EvaluationMeasure, IDecomposableRuleEvaluationFactory> {
public:
Expand Down Expand Up @@ -87,7 +89,7 @@ namespace boosting {
const OutputMatrix& outputMatrix,
std::unique_ptr<DenseDecomposableStatisticMatrix> statisticMatrixPtr,
std::unique_ptr<NumericCContiguousMatrix<float64>> scoreMatrixPtr)
: AbstractDecomposableStatistics<OutputMatrix, QuantizationMatrix, DenseDecomposableStatisticVector,
: AbstractDecomposableStatistics<OutputMatrix, QuantizationMatrix, StatisticVector,
DenseDecomposableStatisticMatrix, NumericCContiguousMatrix<float64>,
Loss, EvaluationMeasure, IDecomposableRuleEvaluationFactory>(
std::move(quantizationMatrixPtr), std::move(lossPtr), std::move(evaluationMeasurePtr),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ namespace boosting {
std::unique_ptr<IDecomposableStatistics<IDecomposableRuleEvaluationFactory>> statisticsPtr;
auto denseDecomposableMatrixVisitor =
[&](std::unique_ptr<IQuantizationMatrix<CContiguousView<Tuple<float64>>>>& quantizationMatrixPtr) {
statisticsPtr = std::make_unique<DenseDecomposableStatistics<
Loss, OutputMatrix, IQuantizationMatrix<CContiguousView<Tuple<float64>>>, EvaluationMeasure>>(
statisticsPtr = std::make_unique<
DenseDecomposableStatistics<Loss, OutputMatrix, IQuantizationMatrix<CContiguousView<Tuple<float64>>>,
DenseDecomposableStatisticVector, EvaluationMeasure>>(
std::move(quantizationMatrixPtr), std::move(lossPtr), std::move(evaluationMeasurePtr),
ruleEvaluationFactory, outputMatrix, std::move(statisticMatrixPtr), std::move(scoreMatrixPtr));
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ namespace boosting {
auto denseDecomposableMatrixVisitor =
[&](std::unique_ptr<IQuantizationMatrix<CContiguousView<Tuple<float64>>>>& quantizationMatrixPtr) {
statisticsPtr = std::make_unique<DenseDecomposableStatistics<
Loss, OutputMatrix, IQuantizationMatrix<CContiguousView<Tuple<float64>>>, EvaluationMeasure>>(
Loss, OutputMatrix, IQuantizationMatrix<CContiguousView<Tuple<float64>>>,
DenseDecomposableStatisticVector, EvaluationMeasure>>(
std::move(quantizationMatrixPtr), std::move(this->lossPtr_),
std::move(this->evaluationMeasurePtr_), ruleEvaluationFactory, this->outputMatrix_,
std::move(decomposableStatisticMatrixPtr), std::move(this->scoreMatrixPtr_));
Expand Down

0 comments on commit d2bad1b

Please sign in to comment.