Skip to content

Commit

Permalink
Remove field "indexedArray" from struct Refinement.
Browse files Browse the repository at this point in the history
  • Loading branch information
michael-rapp committed Oct 2, 2020
1 parent 29a7b6e commit 08bf030
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 24 deletions.
7 changes: 3 additions & 4 deletions python/boomer/common/cpp/rule_refinement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,14 @@ ExactRuleRefinementImpl::~ExactRuleRefinementImpl() {

void ExactRuleRefinementImpl::findRefinement(IHeadRefinement* headRefinement, PredictionCandidate* currentHead,
uint32 numLabelIndices, const uint32* labelIndices) {
// An array that stores the indices and feature values of the training examples
IndexedFloat32Array* indexedArray = callback_->get(featureIndex_);
bestRefinement_.indexedArray = indexedArray;
// The best head seen so far
PredictionCandidate* bestHead = currentHead;
// Create a new, empty subset of the current statistics when processing a new feature...
std::unique_ptr<IStatisticsSubset> statisticsSubsetPtr;
statisticsSubsetPtr.reset(statistics_->createSubset(numLabelIndices, labelIndices));
// The example indices and feature values to be iterated

// Retrieve the array to be iterated...
IndexedFloat32Array* indexedArray = callback_->get(featureIndex_);
IndexedFloat32* indexedValues = indexedArray->data;
uint32 numIndexedValues = indexedArray->numElements;

Expand Down
3 changes: 2 additions & 1 deletion python/boomer/common/cpp/rule_refinement.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ struct Refinement {
intp start;
intp end;
intp previous;
IndexedFloat32Array* indexedArray;
};

/**
Expand Down Expand Up @@ -58,6 +57,8 @@ class AbstractRuleRefinement {

public:

virtual ~AbstractRuleRefinement() { };

/**
* Finds the best refinement of an existing rule and updates the class attribute `bestRefinement_` accordingly.
*
Expand Down
37 changes: 20 additions & 17 deletions python/boomer/common/cpp/thresholds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,8 @@ static inline intp adjustSplit(IndexedFloat32Array* indexedArray, intp condition
* filtered array is stored in a given struct of type `IndexedFloat32ArrayWrapper` and the given statistics are updated
* accordingly.
*
* @param cacheFiltered A map that maps feature indices to structs of type `IndexedFloat32ArrayWrapper`, storing
* the indices of the training examples that are covered by the existing rule, as well as
* their values for the respective feature, sorted in ascending order by the feature values
* @param featureIndex The index of the feature
* @param indexedArrayWrapper A pointer to a struct of type `IndexedFloat32ArrayWrapper` that should be used to store the
* filtered array
* @param indexedArray A pointer to a struct of type `IndexedFloat32Array` that stores a pointer to the array
* to be filtered, as well as the number of elements in the array
* @param conditionStart The element in `indexedValues` that corresponds to the first example (inclusive)
Expand All @@ -86,12 +84,11 @@ static inline intp adjustSplit(IndexedFloat32Array* indexedArray, intp condition
* @return The value that is used to mark those elements in the updated `coveredExamplesMask` that
* are covered by the new rule
*/
static inline uint32 filterCurrentIndices(std::unordered_map<uint32, IndexedFloat32ArrayWrapper*> &cacheFiltered,
uint32 featureIndex, IndexedFloat32Array* indexedArray, intp conditionStart,
intp conditionEnd, Comparator conditionComparator, bool covered,
uint32 numConditions, uint32* coveredExamplesMask,
uint32 coveredExamplesTarget, AbstractStatistics* statistics,
IWeightVector* weights) {
static inline uint32 filterCurrentIndices(IndexedFloat32ArrayWrapper* indexedArrayWrapper,
IndexedFloat32Array* indexedArray, intp conditionStart, intp conditionEnd,
Comparator conditionComparator, bool covered, uint32 numConditions,
uint32* coveredExamplesMask, uint32 coveredExamplesTarget,
AbstractStatistics* statistics, IWeightVector* weights) {
IndexedFloat32* indexedValues = indexedArray->data;
uint32 numIndexedValues = indexedArray->numElements;
bool descending = conditionEnd < conditionStart;
Expand Down Expand Up @@ -183,7 +180,6 @@ static inline uint32 filterCurrentIndices(std::unordered_map<uint32, IndexedFloa
}
}

IndexedFloat32ArrayWrapper* indexedArrayWrapper = cacheFiltered[featureIndex];
IndexedFloat32Array* filteredIndexedArray = indexedArrayWrapper->array;

if (filteredIndexedArray == NULL) {
Expand Down Expand Up @@ -337,22 +333,29 @@ void ExactThresholdsImpl::ThresholdsSubsetImpl::applyRefinement(Refinement &refi
numRefinements_++;
sumOfWeights_ = refinement.coveredWeights;

uint32 featureIndex = refinement.featureIndex;
IndexedFloat32ArrayWrapper* indexedArrayWrapper = cacheFiltered_[featureIndex];
IndexedFloat32Array* indexedArray = indexedArrayWrapper->array;

if (indexedArray == NULL) {
indexedArray = thresholds_->cache_[featureIndex];
}

// If there are examples with zero weights, those examples have not been considered when searching for
// the refinement. In the next step, we need to identify the examples that are covered by the refined rule,
// including those that have previously been ignored, via the function `filterCurrentIndices`. Said function
// calculates the number of covered examples based on the variable `refinement.end`, which represents the position
// that separates the covered from the uncovered examples. However, when taking into account the examples with zero
// weights, this position may differ from the current value of `refinement.end` and therefore must be adjusted...
if (weights_->hasZeroElements() && abs(refinement.previous - refinement.end) > 1) {
refinement.end = adjustSplit(refinement.indexedArray, refinement.end, refinement.previous,
refinement.threshold);
refinement.end = adjustSplit(indexedArray, refinement.end, refinement.previous, refinement.threshold);
}

// Identify the examples that are covered by the refined rule...
coveredExamplesTarget_ = filterCurrentIndices(cacheFiltered_, refinement.featureIndex, refinement.indexedArray,
refinement.start, refinement.end, refinement.comparator,
refinement.covered, numRefinements_, coveredExamplesMask_,
coveredExamplesTarget_, thresholds_->statisticsPtr_.get(), weights_);
coveredExamplesTarget_ = filterCurrentIndices(indexedArrayWrapper, indexedArray, refinement.start, refinement.end,
refinement.comparator, refinement.covered, numRefinements_,
coveredExamplesMask_, coveredExamplesTarget_,
thresholds_->statisticsPtr_.get(), weights_);
}

void ExactThresholdsImpl::ThresholdsSubsetImpl::recalculatePrediction(IHeadRefinement* headRefinement,
Expand Down
2 changes: 0 additions & 2 deletions python/boomer/common/rule_refinement.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
Provides wrappers for classes that allow to find the best refinement of rules.
"""
from boomer.common._arrays cimport uint32, intp, float32
from boomer.common._tuples cimport IndexedFloat32Array
from boomer.common._predictions cimport PredictionCandidate
from boomer.common.rules cimport Comparator
from boomer.common.statistics cimport AbstractStatistics
Expand All @@ -25,7 +24,6 @@ cdef extern from "cpp/rule_refinement.h" nogil:
intp start
intp end
intp previous
IndexedFloat32Array* indexedArray


cdef cppclass AbstractRuleRefinement:
Expand Down

0 comments on commit 08bf030

Please sign in to comment.