Skip to content

Commit

Permalink
Add shared pointer of type RNG as a member of the class StochasticQua…
Browse files Browse the repository at this point in the history
…ntizationMatrix.
  • Loading branch information
michael-rapp committed Sep 13, 2024
1 parent 08c6102 commit e636c93
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 23 deletions.
8 changes: 3 additions & 5 deletions cpp/subprojects/boosting/include/mlrl/boosting/learner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,7 @@ namespace boosting {
* examples.
*/
virtual void useNoQuantization() {
Property<IQuantizationConfig> property = this->getQuantizationConfig();
property.set(std::make_unique<NoQuantizationConfig>());
this->getQuantizationConfig().set(std::make_unique<NoQuantizationConfig>());
}
};

Expand All @@ -343,10 +342,9 @@ namespace boosting {
* configuration of the quantization method
*/
virtual IStochasticQuantizationConfig& useStochasticQuantization() {
Property<IQuantizationConfig> property = this->getQuantizationConfig();
std::unique_ptr<StochasticQuantizationConfig> ptr = std::make_unique<StochasticQuantizationConfig>();
auto ptr = std::make_unique<StochasticQuantizationConfig>(this->getRNGConfig());
IStochasticQuantizationConfig& ref = *ptr;
property.set(std::move(ptr));
this->getQuantizationConfig().set(std::move(ptr));
return ref;
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#pragma once

#include "mlrl/boosting/statistics/quantization.hpp"
#include "mlrl/common/random/rng.hpp"
#include "mlrl/common/util/properties.hpp"

#include <memory>

Expand Down Expand Up @@ -43,11 +45,17 @@ namespace boosting {
public IStochasticQuantizationConfig {
private:

const ReadableProperty<RNGConfig> rngConfig_;

uint32 numBits_;

public:

StochasticQuantizationConfig();
/**
* @param rngConfig A `ReadableProperty` that provides acccess the `RNGConfig` that stores the configuration
* of random number generators
*/
StochasticQuantizationConfig(ReadableProperty<RNGConfig> rngConfig);

uint32 getNumBits() const override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,16 @@ namespace boosting {
class StochasticQuantizationMatrix final : public IQuantizationMatrix<BitDecomposableStatisticView> {
private:

std::shared_ptr<RNG> rngPtr_;

const View& view_;

BitDecomposableStatisticView quantizedView_;

public:

StochasticQuantizationMatrix(const View& view, uint32 numBits)
: view_(view), quantizedView_(view.numRows, view.numCols, numBits) {}
StochasticQuantizationMatrix(std::shared_ptr<RNG> rngPtr, const View& view, uint32 numBits)
: rngPtr_(std::move(rngPtr)), view_(view), quantizedView_(view.numRows, view.numCols, numBits) {}

void quantize(const CompleteIndexVector& outputIndices) override {
// TODO Implement
Expand All @@ -58,57 +60,58 @@ namespace boosting {
const CContiguousView<Tuple<float64>>& statisticMatrix) const override {
return std::make_unique<StochasticQuantization<CContiguousView<Tuple<float64>>>>(
std::make_unique<StochasticQuantizationMatrix<CContiguousView<Tuple<float64>>>>(
statisticMatrix, quantizedView_.firstView.numBitsPerElement));
rngPtr_, statisticMatrix, quantizedView_.firstView.numBitsPerElement));
}

std::unique_ptr<IQuantization> create(const SparseSetView<Tuple<float64>>& statisticMatrix) const override {
return std::make_unique<StochasticQuantization<SparseSetView<Tuple<float64>>>>(
std::make_unique<StochasticQuantizationMatrix<SparseSetView<Tuple<float64>>>>(
statisticMatrix, quantizedView_.firstView.numBitsPerElement));
rngPtr_, statisticMatrix, quantizedView_.firstView.numBitsPerElement));
}

std::unique_ptr<IQuantization> create(
const DenseNonDecomposableStatisticView& statisticMatrix) const override {
return std::make_unique<StochasticQuantization<DenseNonDecomposableStatisticView>>(
std::make_unique<StochasticQuantizationMatrix<DenseNonDecomposableStatisticView>>(
statisticMatrix, quantizedView_.firstView.numBitsPerElement));
rngPtr_, statisticMatrix, quantizedView_.firstView.numBitsPerElement));
}
};

class StochasticQuantizationFactory final : public IQuantizationFactory {
private:

const std::unique_ptr<RNGFactory> rngFactoryPtr_;

uint32 numBits_;

public:

/**
* @param numBits The number of bits to be used for quantized statistics
*/
StochasticQuantizationFactory(uint32 numBits) : numBits_(numBits) {}
StochasticQuantizationFactory(std::unique_ptr<RNGFactory> rngFactoryPtr, uint32 numBits)
: rngFactoryPtr_(std::move(rngFactoryPtr)), numBits_(numBits) {}

std::unique_ptr<IQuantization> create(
const CContiguousView<Tuple<float64>>& statisticMatrix) const override {
return std::make_unique<StochasticQuantization<CContiguousView<Tuple<float64>>>>(
std::make_unique<StochasticQuantizationMatrix<CContiguousView<Tuple<float64>>>>(statisticMatrix,
numBits_));
std::make_unique<StochasticQuantizationMatrix<CContiguousView<Tuple<float64>>>>(
rngFactoryPtr_->create(), statisticMatrix, numBits_));
}

std::unique_ptr<IQuantization> create(const SparseSetView<Tuple<float64>>& statisticMatrix) const override {
return std::make_unique<StochasticQuantization<SparseSetView<Tuple<float64>>>>(
std::make_unique<StochasticQuantizationMatrix<SparseSetView<Tuple<float64>>>>(statisticMatrix,
numBits_));
std::make_unique<StochasticQuantizationMatrix<SparseSetView<Tuple<float64>>>>(
rngFactoryPtr_->create(), statisticMatrix, numBits_));
}

std::unique_ptr<IQuantization> create(
const DenseNonDecomposableStatisticView& statisticMatrix) const override {
return std::make_unique<StochasticQuantization<DenseNonDecomposableStatisticView>>(
std::make_unique<StochasticQuantizationMatrix<DenseNonDecomposableStatisticView>>(statisticMatrix,
numBits_));
std::make_unique<StochasticQuantizationMatrix<DenseNonDecomposableStatisticView>>(
rngFactoryPtr_->create(), statisticMatrix, numBits_));
}
};

StochasticQuantizationConfig::StochasticQuantizationConfig() : numBits_(4) {}
StochasticQuantizationConfig::StochasticQuantizationConfig(ReadableProperty<RNGConfig> rngConfig)
: rngConfig_(rngConfig), numBits_(4) {}

uint32 StochasticQuantizationConfig::getNumBits() const {
return numBits_;
Expand All @@ -121,7 +124,7 @@ namespace boosting {
}

std::unique_ptr<IQuantizationFactory> StochasticQuantizationConfig::createQuantizationFactory() const {
return std::make_unique<StochasticQuantizationFactory>(numBits_);
return std::make_unique<StochasticQuantizationFactory>(rngConfig_.get().createRNGFactory(), numBits_);
}

}

0 comments on commit e636c93

Please sign in to comment.