Skip to content

Commit

Permalink
Add create functions that accept arguments of type BitDecomposableSta…
Browse files Browse the repository at this point in the history
…tisticVector to the class IDecomposableRuleEvaluationFactory.
  • Loading branch information
michael-rapp committed Aug 30, 2024
1 parent d7fb124 commit 7eb512b
Show file tree
Hide file tree
Showing 15 changed files with 183 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
*/
#pragma once

#include "mlrl/boosting/data/vector_statistic_decomposable_bit.hpp"
#include "mlrl/boosting/data/vector_statistic_decomposable_dense.hpp"
#include "mlrl/boosting/rule_evaluation/rule_evaluation.hpp"
#include "mlrl/common/indices/index_vector_complete.hpp"
Expand Down Expand Up @@ -52,6 +53,36 @@ namespace boosting {
*/
virtual std::unique_ptr<IRuleEvaluation<DenseDecomposableStatisticVector>> create(
const DenseDecomposableStatisticVector& statisticVector, const PartialIndexVector& indexVector) const = 0;

/**
* Creates a new instance of the class `IRuleEvaluation` that allows to calculate the predictions of rules
* that predict for all available outputs, based on the gradients and Hessians that are stored by a
* `BitDecomposableStatisticVector`.
*
* @param statisticVector A reference to an object of type `BitDecomposableStatisticVector`. This vector
* is only used to identify the function that is able to deal with this particular
* type of vector via function overloading
* @param indexVector A reference to an object of the type `CompleteIndexVector` that provides access
* to the indices of the outputs for which the rules may predict
* @return An unique pointer to an object of type `IRuleEvaluation` that has been created
*/
virtual std::unique_ptr<IRuleEvaluation<BitDecomposableStatisticVector>> create(
const BitDecomposableStatisticVector& statisticVector, const CompleteIndexVector& indexVector) const = 0;

/**
* Creates a new instance of the class `IRuleEvaluation` that allows to calculate the predictions of rules
* that predict for a subset of the available outputs, based on the gradients and Hessians that are stored
* by a `BitDecomposableStatisticVector`.
*
* @param statisticVector A reference to an object of type `BitDecomposableStatisticVector`. This vector
* is only used to identify the function that is able to deal with this particular
* type of vector via function overloading
* @param indexVector A reference to an object of the type `PartialIndexVector` that provides access
* to the indices of the outputs for which the rules may predict
* @return An unique pointer to an object of type `IRuleEvaluation` that has been created
*/
virtual std::unique_ptr<IRuleEvaluation<BitDecomposableStatisticVector>> create(
const BitDecomposableStatisticVector& statisticVector, const PartialIndexVector& indexVector) const = 0;
};

}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ namespace boosting {
std::unique_ptr<IRuleEvaluation<DenseDecomposableStatisticVector>> create(
const DenseDecomposableStatisticVector& statisticVector,
const PartialIndexVector& indexVector) const override;

std::unique_ptr<IRuleEvaluation<BitDecomposableStatisticVector>> create(
const BitDecomposableStatisticVector& statisticVector,
const CompleteIndexVector& indexVector) const override;

std::unique_ptr<IRuleEvaluation<BitDecomposableStatisticVector>> create(
const BitDecomposableStatisticVector& statisticVector,
const PartialIndexVector& indexVector) const override;
};

}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ namespace boosting {
std::unique_ptr<IRuleEvaluation<DenseDecomposableStatisticVector>> create(
const DenseDecomposableStatisticVector& statisticVector,
const PartialIndexVector& indexVector) const override;

std::unique_ptr<IRuleEvaluation<BitDecomposableStatisticVector>> create(
const BitDecomposableStatisticVector& statisticVector,
const CompleteIndexVector& indexVector) const override;

std::unique_ptr<IRuleEvaluation<BitDecomposableStatisticVector>> create(
const BitDecomposableStatisticVector& statisticVector,
const PartialIndexVector& indexVector) const override;
};

}
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ namespace boosting {
const DenseDecomposableStatisticVector& statisticVector,
const PartialIndexVector& indexVector) const override;

std::unique_ptr<IRuleEvaluation<BitDecomposableStatisticVector>> create(
const BitDecomposableStatisticVector& statisticVector,
const CompleteIndexVector& indexVector) const override;

std::unique_ptr<IRuleEvaluation<BitDecomposableStatisticVector>> create(
const BitDecomposableStatisticVector& statisticVector,
const PartialIndexVector& indexVector) const override;

std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector>> create(
const SparseDecomposableStatisticVector& statisticVector,
const CompleteIndexVector& indexVector) const override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ namespace boosting {
const DenseDecomposableStatisticVector& statisticVector,
const PartialIndexVector& indexVector) const override;

std::unique_ptr<IRuleEvaluation<BitDecomposableStatisticVector>> create(
const BitDecomposableStatisticVector& statisticVector,
const CompleteIndexVector& indexVector) const override;

std::unique_ptr<IRuleEvaluation<BitDecomposableStatisticVector>> create(
const BitDecomposableStatisticVector& statisticVector,
const PartialIndexVector& indexVector) const override;

std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector>> create(
const SparseDecomposableStatisticVector& statisticVector,
const CompleteIndexVector& indexVector) const override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ namespace boosting {
const DenseDecomposableStatisticVector& statisticVector,
const PartialIndexVector& indexVector) const override;

std::unique_ptr<IRuleEvaluation<BitDecomposableStatisticVector>> create(
const BitDecomposableStatisticVector& statisticVector,
const CompleteIndexVector& indexVector) const override;

std::unique_ptr<IRuleEvaluation<BitDecomposableStatisticVector>> create(
const BitDecomposableStatisticVector& statisticVector,
const PartialIndexVector& indexVector) const override;

std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector>> create(
const SparseDecomposableStatisticVector& statisticVector,
const CompleteIndexVector& indexVector) const override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ namespace boosting {
const DenseDecomposableStatisticVector& statisticVector,
const PartialIndexVector& indexVector) const override;

std::unique_ptr<IRuleEvaluation<BitDecomposableStatisticVector>> create(
const BitDecomposableStatisticVector& statisticVector,
const CompleteIndexVector& indexVector) const override;

std::unique_ptr<IRuleEvaluation<BitDecomposableStatisticVector>> create(
const BitDecomposableStatisticVector& statisticVector,
const PartialIndexVector& indexVector) const override;

std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector>> create(
const SparseDecomposableStatisticVector& statisticVector,
const CompleteIndexVector& indexVector) const override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ namespace boosting {
const DenseDecomposableStatisticVector& statisticVector,
const PartialIndexVector& indexVector) const override;

std::unique_ptr<IRuleEvaluation<BitDecomposableStatisticVector>> create(
const BitDecomposableStatisticVector& statisticVector,
const CompleteIndexVector& indexVector) const override;

std::unique_ptr<IRuleEvaluation<BitDecomposableStatisticVector>> create(
const BitDecomposableStatisticVector& statisticVector,
const PartialIndexVector& indexVector) const override;

std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector>> create(
const SparseDecomposableStatisticVector& statisticVector,
const CompleteIndexVector& indexVector) const override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,16 @@ namespace boosting {
indexVector, l1RegularizationWeight_, l2RegularizationWeight_);
}

std::unique_ptr<IRuleEvaluation<BitDecomposableStatisticVector>> DecomposableCompleteRuleEvaluationFactory::create(
const BitDecomposableStatisticVector& statisticVector, const CompleteIndexVector& indexVector) const {
// TODO Implement
return nullptr;
}

std::unique_ptr<IRuleEvaluation<BitDecomposableStatisticVector>> DecomposableCompleteRuleEvaluationFactory::create(
const BitDecomposableStatisticVector& statisticVector, const PartialIndexVector& indexVector) const {
// TODO Implement
return nullptr;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,18 @@ namespace boosting {
indexVector, l1RegularizationWeight_, l2RegularizationWeight_, std::move(labelBinningPtr));
}

std::unique_ptr<IRuleEvaluation<BitDecomposableStatisticVector>>
DecomposableCompleteBinnedRuleEvaluationFactory::create(const BitDecomposableStatisticVector& statisticVector,
const CompleteIndexVector& indexVector) const {
// TODO Implement
return nullptr;
}

std::unique_ptr<IRuleEvaluation<BitDecomposableStatisticVector>>
DecomposableCompleteBinnedRuleEvaluationFactory::create(const BitDecomposableStatisticVector& statisticVector,
const PartialIndexVector& indexVector) const {
// TODO Implement
return nullptr;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,20 @@ namespace boosting {
indexVector, l1RegularizationWeight_, l2RegularizationWeight_);
}

std::unique_ptr<IRuleEvaluation<BitDecomposableStatisticVector>>
DecomposableDynamicPartialRuleEvaluationFactory::create(const BitDecomposableStatisticVector& statisticVector,
const CompleteIndexVector& indexVector) const {
// TODO Implement
return nullptr;
}

std::unique_ptr<IRuleEvaluation<BitDecomposableStatisticVector>>
DecomposableDynamicPartialRuleEvaluationFactory::create(const BitDecomposableStatisticVector& statisticVector,
const PartialIndexVector& indexVector) const {
// TODO Implement
return nullptr;
}

std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector>>
DecomposableDynamicPartialRuleEvaluationFactory::create(const SparseDecomposableStatisticVector& statisticVector,
const CompleteIndexVector& indexVector) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,20 @@ namespace boosting {
indexVector, l1RegularizationWeight_, l2RegularizationWeight_, std::move(labelBinningPtr));
}

std::unique_ptr<IRuleEvaluation<BitDecomposableStatisticVector>>
DecomposableDynamicPartialBinnedRuleEvaluationFactory::create(
const BitDecomposableStatisticVector& statisticVector, const CompleteIndexVector& indexVector) const {
// TODO Implement
return nullptr;
}

std::unique_ptr<IRuleEvaluation<BitDecomposableStatisticVector>>
DecomposableDynamicPartialBinnedRuleEvaluationFactory::create(
const BitDecomposableStatisticVector& statisticVector, const PartialIndexVector& indexVector) const {
// TODO Implement
return nullptr;
}

std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector>>
DecomposableDynamicPartialBinnedRuleEvaluationFactory::create(
const SparseDecomposableStatisticVector& statisticVector, const CompleteIndexVector& indexVector) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,20 @@ namespace boosting {
indexVector, l1RegularizationWeight_, l2RegularizationWeight_);
}

std::unique_ptr<IRuleEvaluation<BitDecomposableStatisticVector>>
DecomposableFixedPartialRuleEvaluationFactory::create(const BitDecomposableStatisticVector& statisticVector,
const CompleteIndexVector& indexVector) const {
// TODO Implement
return nullptr;
}

std::unique_ptr<IRuleEvaluation<BitDecomposableStatisticVector>>
DecomposableFixedPartialRuleEvaluationFactory::create(const BitDecomposableStatisticVector& statisticVector,
const PartialIndexVector& indexVector) const {
// TODO Implement
return nullptr;
}

std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector>>
DecomposableFixedPartialRuleEvaluationFactory::create(const SparseDecomposableStatisticVector& statisticVector,
const CompleteIndexVector& indexVector) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,20 @@ namespace boosting {
indexVector, l1RegularizationWeight_, l2RegularizationWeight_, std::move(labelBinningPtr));
}

std::unique_ptr<IRuleEvaluation<BitDecomposableStatisticVector>>
DecomposableFixedPartialBinnedRuleEvaluationFactory::create(const BitDecomposableStatisticVector& statisticVector,
const CompleteIndexVector& indexVector) const {
// TODO Implement
return nullptr;
}

std::unique_ptr<IRuleEvaluation<BitDecomposableStatisticVector>>
DecomposableFixedPartialBinnedRuleEvaluationFactory::create(const BitDecomposableStatisticVector& statisticVector,
const PartialIndexVector& indexVector) const {
// TODO Implement
return nullptr;
}

std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector>>
DecomposableFixedPartialBinnedRuleEvaluationFactory::create(
const SparseDecomposableStatisticVector& statisticVector, const CompleteIndexVector& indexVector) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,20 @@ namespace boosting {
indexVector, l1RegularizationWeight_, l2RegularizationWeight_);
}

std::unique_ptr<IRuleEvaluation<BitDecomposableStatisticVector>>
DecomposableSingleOutputRuleEvaluationFactory::create(const BitDecomposableStatisticVector& statisticVector,
const CompleteIndexVector& indexVector) const {
// TODO Implement
return nullptr;
}

std::unique_ptr<IRuleEvaluation<BitDecomposableStatisticVector>>
DecomposableSingleOutputRuleEvaluationFactory::create(const BitDecomposableStatisticVector& statisticVector,
const PartialIndexVector& indexVector) const {
// TODO Implement
return nullptr;
}

std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector>>
DecomposableSingleOutputRuleEvaluationFactory::create(const SparseDecomposableStatisticVector& statisticVector,
const CompleteIndexVector& indexVector) const {
Expand Down

0 comments on commit 7eb512b

Please sign in to comment.