diff --git a/python/boomer/common/cpp/rule_refinement.cpp b/python/boomer/common/cpp/rule_refinement.cpp
index 3946aa6735..043e5a0927 100644
--- a/python/boomer/common/cpp/rule_refinement.cpp
+++ b/python/boomer/common/cpp/rule_refinement.cpp
@@ -469,3 +469,103 @@ void ExactRuleRefinementImpl::findRefinement(IHeadRefinement* headRefinement, Pr
         }
     }
 }
+
+/**
+ * Creates a rule refinement that searches for thresholds between the bins of a single feature. The object deletes
+ * `callback` in its destructor, i.e., it takes ownership of the callback.
+ */
+ApproximateRuleRefinementImpl::ApproximateRuleRefinementImpl(AbstractStatistics* statistics, uint32 featureIndex,
+                                                             IRuleRefinementCallback* callback)
+    : statistics_(statistics), featureIndex_(featureIndex), callback_(callback) {
+
+}
+
+ApproximateRuleRefinementImpl::~ApproximateRuleRefinementImpl() {
+    delete callback_;
+}
+
+void ApproximateRuleRefinementImpl::findRefinement(IHeadRefinement* headRefinement, PredictionCandidate* currentHead,
+                                                   uint32 numLabelIndices, const uint32* labelIndices) {
+    BinArray* binArray = callback_->get(0);
+    uint32 numBins = binArray->numBins;
+    Refinement refinement;
+    refinement.featureIndex = featureIndex_;
+    refinement.head = NULL;
+    refinement.start = 0;
+
+    // Search for the first non-empty bin. The bound is tested before the bin is read, because otherwise the element
+    // `bins[numBins]`, which is past the end of the array, would be accessed if all bins are empty.
+    uint32 r = 0;
+
+    while (r < numBins && binArray->bins[r].numExamples == 0) {
+        r++;
+    }
+
+    if (r >= numBins) {
+        // There is no non-empty bin, i.e., no threshold can be derived. A refinement without a head is reported.
+        bestRefinement_ = refinement;
+        return;
+    }
+
+    PredictionCandidate* bestHead = currentHead;
+
+    // The element type of the smart pointer is derived from the return type of `createSubset`, such that the subset
+    // is released when this function returns, regardless of the concrete type that is returned.
+    typedef decltype(statistics_->createSubset(numLabelIndices, labelIndices)) SubsetPointer;
+    std::unique_ptr<std::pointer_traits<SubsetPointer>::element_type> statisticsSubsetPtr(
+        statistics_->createSubset(numLabelIndices, labelIndices));
+
+    // The first non-empty bin is added to the subset before any threshold is evaluated
+    statisticsSubsetPtr->addToSubset(r, 1);
+    uint32 previousR = r;
+    float32 previousValue = binArray->bins[r].maxValue;
+    uint32 numCoveredExamples = binArray->bins[r].numExamples;
+
+    for (r = r + 1; r < numBins; r++) {
+        uint32 numExamples = binArray->bins[r].numExamples;
+
+        // Empty bins are skipped, a threshold is only evaluated between two non-empty bins
+        if (numExamples > 0) {
+            float32 currentValue = binArray->bins[r].minValue;
+
+            // A non-NULL result replaces the best head found so far and is recorded as a condition that uses the
+            // comparator LEQ. NOTE(review): the meaning of the two boolean arguments is defined by `findHead`.
+            PredictionCandidate* newHead = headRefinement->findHead(bestHead, refinement.head, labelIndices,
+                                                                    statisticsSubsetPtr.get(), false, false);
+
+            if (newHead != NULL) {
+                bestHead = newHead;
+                refinement.head = newHead;
+                refinement.comparator = LEQ;
+                refinement.threshold = (previousValue + currentValue) / 2.0;
+                refinement.end = r;
+                refinement.previous = previousR;
+                refinement.coveredWeights = numCoveredExamples;
+                refinement.covered = true;
+            }
+
+            // The same threshold is evaluated a second time and, if successful, recorded with the comparator GR
+            newHead = headRefinement->findHead(bestHead, refinement.head, labelIndices, statisticsSubsetPtr.get(),
+                                               true, false);
+
+            if (newHead != NULL) {
+                bestHead = newHead;
+                refinement.head = newHead;
+                refinement.comparator = GR;
+                refinement.threshold = (previousValue + currentValue) / 2.0;
+                refinement.end = r;
+                refinement.previous = previousR;
+                refinement.coveredWeights = numCoveredExamples;
+                refinement.covered = false;
+            }
+
+            // The current bin is added to the subset before the next threshold is evaluated
+            previousValue = binArray->bins[r].maxValue;
+            previousR = r;
+            numCoveredExamples += numExamples;
+            statisticsSubsetPtr->addToSubset(r, 1);
+        }
+    }
+
+    bestRefinement_ = refinement;
+}
diff --git a/python/boomer/common/cpp/rule_refinement.h b/python/boomer/common/cpp/rule_refinement.h
index 8c25c347c6..8f378a740c 100644
--- a/python/boomer/common/cpp/rule_refinement.h
+++ b/python/boomer/common/cpp/rule_refinement.h
@@ -2,6 +2,7 @@
  * Implements classes that allow to find the best refinement of rules.
  *
  * @author Michael Rapp (mrapp@ke.tu-darmstadt.de)
+ * @author Lukas Johannes Eberle (lukasjohannes.eberle@stud.tu-darmstadt.de)
  */
 
 #pragma once
@@ -126,3 +127,44 @@ class ExactRuleRefinementImpl : public AbstractRuleRefinement {
                             const uint32* labelIndices) override;
 
 };
+
+/**
+ * Allows to find the best refinements of existing rules, which result from adding a new condition that corresponds to
+ * a certain feature. The thresholds that may be used by the new condition result from the bins that have been
+ * created using a binning method.
+ *
+ * An object of this class deletes the callback it was given when it is destroyed. It must therefore not be copied,
+ * because both copies would delete the same callback.
+ */
+class ApproximateRuleRefinementImpl : public AbstractRuleRefinement {
+
+    private:
+
+        AbstractStatistics* statistics_;
+
+        uint32 featureIndex_;
+
+        IRuleRefinementCallback* callback_;
+
+    public:
+
+        /**
+         * @param statistics    A pointer to an object of type `AbstractStatistics` that provides access to the
+         *                      statistics which serve as the basis for evaluating the potential refinements of rules
+         * @param featureIndex  The index of the feature, the new condition corresponds to
+         * @param callback      A pointer to an object of type `IRuleRefinementCallback` that allows to
+         *                      retrieve the information that is required to identify potential refinements. The
+         *                      callback is deleted by the destructor of this class
+         */
+        ApproximateRuleRefinementImpl(AbstractStatistics* statistics, uint32 featureIndex,
+                                      IRuleRefinementCallback* callback);
+
+        /**
+         * Deletes the callback that has been passed to the constructor.
+         */
+        ~ApproximateRuleRefinementImpl();
+
+        void findRefinement(IHeadRefinement* headRefinement, PredictionCandidate* currentHead,
+                            uint32 numLabelIndices, const uint32* labelIndices) override;
+
+};
diff --git a/python/boomer/common/cpp/tuples.h b/python/boomer/common/cpp/tuples.h
index d5585a0c89..75c4a1d184 100644
--- a/python/boomer/common/cpp/tuples.h
+++ b/python/boomer/common/cpp/tuples.h
@@ -56,6 +56,15 @@ struct Bin {
     float32 maxValue;
 };
 
+/**
+ * A struct that contains a pointer to an array of type `Bin`. The attribute `numBins` specifies how many elements the
+ * array contains.
+ */
+struct BinArray {
+    uint32 numBins;
+    Bin* bins;
+};
+
 namespace tuples {
 
     /**