diff --git a/python/boomer/common/cpp/rule_refinement.cpp b/python/boomer/common/cpp/rule_refinement.cpp index 4b5c9ada84..f135778ff5 100644 --- a/python/boomer/common/cpp/rule_refinement.cpp +++ b/python/boomer/common/cpp/rule_refinement.cpp @@ -482,6 +482,72 @@ ApproximateRuleRefinementImpl::ApproximateRuleRefinementImpl(AbstractStatistics* Refinement ApproximateRuleRefinementImpl::findRefinement(IHeadRefinement* headRefinement, PredictionCandidate* currentHead, uint32 numLabelIndices, const uint32* labelIndices) { + uint32 numBins = binArray_->numBins; Refinement refinement; + refinement.featureIndex = featureIndex_; + refinement.head = NULL; + refinement.indexedArray = NULL; + refinement.indexedArrayWrapper = NULL; + + PredictionCandidate* dynamicCurrentHead = currentHead; + + PredictionCandidate* bestHead = currentHead; + + std::unique_ptr statisticsSubsetPtr; + statisticsSubsetPtr.reset(statistics_->createSubset(numLabelIndices, labelIndices)); + + uint32 r = 0; + while(binArray_->bins[r].numExamples == 0 && r < numBins){ + r++; + } + statisticsSubsetPtr.get()->addToSubset(r, 1); + uint32 previousR = r; + float32 previousValue = binArray_->bins[r].maxValue; + uint32 numCoveredExamples = binArray_->bins[r].numExamples; + + r += 1; + for(r; r < numBins; r++){ + uint32 numExamples = binArray_->bins[r].numExamples; + + if(numExamples > 0){ + numCoveredExamples += numExamples; //Das sollten wir, anders wie im Pseudo Code, besser schon hier machen, oder? + float32 currentValue = binArray_->bins[r].minValue; + dynamicCurrentHead = headRefinement->findHead(bestHead, refinement.head, labelIndices, + statisticsSubsetPtr.get(), false, false); + + if(dynamicCurrentHead != NULL){ + bestHead = dynamicCurrentHead; + refinement.comparator = LEQ; + refinement.threshold = (previousValue + currentValue)/2.0; + refinement.start = 0; + refinement.end = r; + refinement.previous = previousR; + refinement.coveredWeights = numCoveredExamples; + refinement.covered = true; + } + + dynamicCurrentHead = headRefinement->findHead(bestHead, refinement.head, labelIndices, + statisticsSubsetPtr.get(), false, false); + + if(dynamicCurrentHead != NULL){ + bestHead = dynamicCurrentHead; + refinement.comparator = GR; + refinement.threshold = (previousValue + currentValue)/2.0; + refinement.start = 0; + refinement.end = r; + refinement.previous = previousR; + refinement.coveredWeights = numCoveredExamples; + refinement.covered = false; + } + + previousValue = binArray_->bins[r].maxValue; + previousR = r; + statisticsSubsetPtr.get()->addToSubset(r, 1); + + } + + } + + return refinement; }