Skip to content

Commit

Permalink
Add constructor argument "view" to vectors for storing statistics.
Browse files Browse the repository at this point in the history
  • Loading branch information
michael-rapp committed Oct 2, 2024
1 parent 9d7d048 commit 7070088
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ namespace boosting {
public:

/**
* @param numElements The number of gradients and Hessians in the vector
* @param numBitsPerElement The number of bits per element in the bit vector
* @param init True, if all gradients and Hessians in the vector should be initialized with
* zero, false otherwise
* @param view A reference to an object of type `BitDecomposableStatisticView`
* @param numElements The number of gradients and Hessians in the vector
* @param init True, if all gradients and Hessians in the vector should be initialized with zero,
* false otherwise
*/
BitDecomposableStatisticVector(uint32 numElements, uint32 numBitsPerElement, bool init = false);
BitDecomposableStatisticVector(const BitDecomposableStatisticView& view, uint32 numElements,
bool init = false);

/**
* @param other A reference to an object of type `BitDecomposableStatisticVector` to be copied
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ namespace boosting {
public:

/**
* @param view A reference to an object of type `CContiguousView`
* @param numElements The number of gradients and Hessians in the vector
* @param init True, if all gradients and Hessians in the vector should be initialized with zero,
* false otherwise
*/
DenseDecomposableStatisticVector(uint32 numElements, bool init = false);
DenseDecomposableStatisticVector(const CContiguousView<Tuple<float64>>& view, uint32 numElements,
bool init = false);

/**
* @param other A reference to an object of type `DenseDecomposableStatisticVector` to be copied
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,13 @@ namespace boosting {
public:

/**
* @param view A reference to an object of type `SparseSetView`
* @param numElements The number of gradients and Hessians in the vector
* @param init True, if all gradients and Hessians in the vector should be initialized with zero,
* false otherwise
*/
SparseDecomposableStatisticVector(uint32 numElements, bool init = false);
SparseDecomposableStatisticVector(const SparseSetView<Tuple<float64>>& view, uint32 numElements,
bool init = false);

/**
* @param other A reference to an object of type `SparseDecomposableStatisticVector` to be copied
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ namespace boosting {
public:

/**
* @param numGradients The number of gradients in the vector
* @param init True, if all gradients and Hessians in the vector should be initialized with zero,
* false otherwise
* @param view A reference to an object of type `DenseNonDecomposableStatisticView`
* @param numGradients The number of gradients in the vector
* @param init True, if all gradients and Hessians in the vector should be initialized with zero,
* false otherwise
*/
DenseNonDecomposableStatisticVector(uint32 numGradients, bool init = false);
DenseNonDecomposableStatisticVector(const DenseNonDecomposableStatisticView& view, uint32 numGradients,
bool init = false);

/**
* @param other A reference to an object of type `DenseNonDecomposableStatisticVector` to be copied
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,16 @@ namespace boosting {
}
}

BitDecomposableStatisticVector::BitDecomposableStatisticVector(uint32 numElements, uint32 numBitsPerElement,
bool init)
BitDecomposableStatisticVector::BitDecomposableStatisticVector(const BitDecomposableStatisticView& view,
uint32 numElements, bool init)
: CompositeView<AllocatedBitVector<uint32>, AllocatedBitVector<uint32>>(
AllocatedBitVector<uint32>(numElements, numBitsPerElement, init),
AllocatedBitVector<uint32>(numElements, numBitsPerElement, init)) {}
AllocatedBitVector<uint32>(numElements, view.firstView.numBitsPerElement, init),
AllocatedBitVector<uint32>(numElements, view.secondView.numBitsPerElement, init)) {}

BitDecomposableStatisticVector::BitDecomposableStatisticVector(const BitDecomposableStatisticVector& other)
: BitDecomposableStatisticVector(other.getNumElements(), other.getNumBitsPerElement()) {
: CompositeView<AllocatedBitVector<uint32>, AllocatedBitVector<uint32>>(
AllocatedBitVector<uint32>(other.firstView.numElements, other.firstView.numBitsPerElement),
AllocatedBitVector<uint32>(other.secondView.numElements, other.secondView.numBitsPerElement)) {
copyInternally(other.firstView, this->firstView);
copyInternally(other.secondView, this->secondView);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

namespace boosting {

DenseDecomposableStatisticVector::DenseDecomposableStatisticVector(uint32 numElements, bool init)
DenseDecomposableStatisticVector::DenseDecomposableStatisticVector(const CContiguousView<Tuple<float64>>& view,
uint32 numElements, bool init)
: ClearableViewDecorator<DenseVectorDecorator<AllocatedVector<Tuple<float64>>>>(
AllocatedVector<Tuple<float64>>(numElements, init)) {}

DenseDecomposableStatisticVector::DenseDecomposableStatisticVector(const DenseDecomposableStatisticVector& other)
: DenseDecomposableStatisticVector(other.getNumElements()) {
: ClearableViewDecorator<DenseVectorDecorator<AllocatedVector<Tuple<float64>>>>(
AllocatedVector<Tuple<float64>>(other.getNumElements())) {
util::copyView(other.cbegin(), this->begin(), this->getNumElements());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,15 @@ namespace boosting {
return iterator_ - rhs.iterator_;
}

SparseDecomposableStatisticVector::SparseDecomposableStatisticVector(uint32 numElements, bool init)
SparseDecomposableStatisticVector::SparseDecomposableStatisticVector(const SparseSetView<Tuple<float64>>& view,
uint32 numElements, bool init)
: ClearableViewDecorator<VectorDecorator<AllocatedVector<Triple<float64>>>>(
AllocatedVector<Triple<float64>>(numElements, init)),
sumOfWeights_(0) {}

SparseDecomposableStatisticVector::SparseDecomposableStatisticVector(const SparseDecomposableStatisticVector& other)
: SparseDecomposableStatisticVector(other.getNumElements()) {
: ClearableViewDecorator<VectorDecorator<AllocatedVector<Triple<float64>>>>(
AllocatedVector<Triple<float64>>(other.getNumElements())) {
util::copyView(other.view.cbegin(), this->view.begin(), this->getNumElements());
sumOfWeights_ = other.sumOfWeights_;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@

namespace boosting {

DenseNonDecomposableStatisticVector::DenseNonDecomposableStatisticVector(uint32 numGradients, bool init)
DenseNonDecomposableStatisticVector::DenseNonDecomposableStatisticVector(
const DenseNonDecomposableStatisticView& view, uint32 numGradients, bool init)
: ClearableViewDecorator<ViewDecorator<CompositeVector<AllocatedVector<float64>, AllocatedVector<float64>>>>(
CompositeVector<AllocatedVector<float64>, AllocatedVector<float64>>(
AllocatedVector<float64>(numGradients, init),
AllocatedVector<float64>(util::triangularNumber(numGradients), init))) {}

DenseNonDecomposableStatisticVector::DenseNonDecomposableStatisticVector(
const DenseNonDecomposableStatisticVector& other)
: DenseNonDecomposableStatisticVector(other.getNumGradients()) {
: ClearableViewDecorator<ViewDecorator<CompositeVector<AllocatedVector<float64>, AllocatedVector<float64>>>>(
CompositeVector<AllocatedVector<float64>, AllocatedVector<float64>>(
AllocatedVector<float64>(other.getNumGradients()), AllocatedVector<float64>(other.getNumHessians()))) {
util::copyView(other.gradients_cbegin(), this->gradients_begin(), this->getNumGradients());
util::copyView(other.hessians_cbegin(), this->hessians_begin(), this->getNumHessians());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ namespace boosting {
*/
StatisticsSubset(const StatisticView& statisticView, const RuleEvaluationFactory& ruleEvaluationFactory,
const WeightVector& weights, const IndexVector& outputIndices)
: sumVector_(outputIndices.getNumElements(), true), statisticView_(statisticView), weights_(weights),
outputIndices_(outputIndices),
: sumVector_(statisticView, outputIndices.getNumElements(), true), statisticView_(statisticView),
weights_(weights), outputIndices_(outputIndices),
ruleEvaluationPtr_(ruleEvaluationFactory.create(sumVector_, outputIndices)) {}

/**
Expand Down Expand Up @@ -182,7 +182,8 @@ namespace boosting {
: StatisticsSubset<StatisticVector, StatisticView, RuleEvaluationFactory, WeightVector,
IndexVector>(statistics.statisticView_, statistics.ruleEvaluationFactory_,
statistics.weights_, outputIndices),
tmpVector_(outputIndices.getNumElements()), totalSumVector_(&totalSumVector) {}
tmpVector_(statistics.statisticView_, outputIndices.getNumElements()),
totalSumVector_(&totalSumVector) {}

/**
* @see `IResettableStatisticsSubset::resetSubset`
Expand Down Expand Up @@ -392,7 +393,7 @@ namespace boosting {
const WeightVector& weights)
: AbstractWeightedStatistics<StatisticVector, StatisticView, RuleEvaluationFactory, WeightVector>(
statisticView, ruleEvaluationFactory, weights),
totalSumVectorPtr_(std::make_unique<StatisticVector>(statisticView.numCols, true)) {
totalSumVectorPtr_(std::make_unique<StatisticVector>(statisticView, statisticView.numCols, true)) {
uint32 numStatistics = weights.getNumElements();

for (uint32 i = 0; i < numStatistics; i++) {
Expand Down

0 comments on commit 7070088

Please sign in to comment.