diff --git a/cpp/subprojects/boosting/include/mlrl/boosting/learner.hpp b/cpp/subprojects/boosting/include/mlrl/boosting/learner.hpp index c731da927..ec232e7b4 100644 --- a/cpp/subprojects/boosting/include/mlrl/boosting/learner.hpp +++ b/cpp/subprojects/boosting/include/mlrl/boosting/learner.hpp @@ -321,8 +321,7 @@ namespace boosting { * examples. */ virtual void useNoQuantization() { - Property property = this->getQuantizationConfig(); - property.set(std::make_unique()); + this->getQuantizationConfig().set(std::make_unique()); } }; @@ -343,10 +342,9 @@ namespace boosting { * configuration of the quantization method */ virtual IStochasticQuantizationConfig& useStochasticQuantization() { - Property property = this->getQuantizationConfig(); - std::unique_ptr ptr = std::make_unique(); + auto ptr = std::make_unique(this->getRNGConfig()); IStochasticQuantizationConfig& ref = *ptr; - property.set(std::move(ptr)); + this->getQuantizationConfig().set(std::move(ptr)); return ref; } }; diff --git a/cpp/subprojects/boosting/include/mlrl/boosting/statistics/quantization_stochastic.hpp b/cpp/subprojects/boosting/include/mlrl/boosting/statistics/quantization_stochastic.hpp index 7ff50a574..b6478dfe5 100644 --- a/cpp/subprojects/boosting/include/mlrl/boosting/statistics/quantization_stochastic.hpp +++ b/cpp/subprojects/boosting/include/mlrl/boosting/statistics/quantization_stochastic.hpp @@ -4,6 +4,8 @@ #pragma once #include "mlrl/boosting/statistics/quantization.hpp" +#include "mlrl/common/random/rng.hpp" +#include "mlrl/common/util/properties.hpp" #include @@ -43,11 +45,17 @@ namespace boosting { public IStochasticQuantizationConfig { private: + const ReadableProperty rngConfig_; + uint32 numBits_; public: - StochasticQuantizationConfig(); + /** + * @param rngConfig A `ReadableProperty` that provides acccess the `RNGConfig` that stores the configuration + * of random number generators + */ + StochasticQuantizationConfig(ReadableProperty rngConfig); uint32 getNumBits() const override; diff --git a/cpp/subprojects/boosting/src/mlrl/boosting/statistics/quantization_stochastic.cpp b/cpp/subprojects/boosting/src/mlrl/boosting/statistics/quantization_stochastic.cpp index 7c0799448..ac90bc4aa 100644 --- a/cpp/subprojects/boosting/src/mlrl/boosting/statistics/quantization_stochastic.cpp +++ b/cpp/subprojects/boosting/src/mlrl/boosting/statistics/quantization_stochastic.cpp @@ -33,14 +33,16 @@ namespace boosting { class StochasticQuantizationMatrix final : public IQuantizationMatrix { private: + std::shared_ptr rngPtr_; + const View& view_; BitDecomposableStatisticView quantizedView_; public: - StochasticQuantizationMatrix(const View& view, uint32 numBits) - : view_(view), quantizedView_(view.numRows, view.numCols, numBits) {} + StochasticQuantizationMatrix(std::shared_ptr rngPtr, const View& view, uint32 numBits) + : rngPtr_(std::move(rngPtr)), view_(view), quantizedView_(view.numRows, view.numCols, numBits) {} void quantize(const CompleteIndexVector& outputIndices) override { // TODO Implement @@ -58,57 +60,58 @@ namespace boosting { const CContiguousView>& statisticMatrix) const override { return std::make_unique>>>( std::make_unique>>>( - statisticMatrix, quantizedView_.firstView.numBitsPerElement)); + rngPtr_, statisticMatrix, quantizedView_.firstView.numBitsPerElement)); } std::unique_ptr create(const SparseSetView>& statisticMatrix) const override { return std::make_unique>>>( std::make_unique>>>( - statisticMatrix, quantizedView_.firstView.numBitsPerElement)); + rngPtr_, statisticMatrix, quantizedView_.firstView.numBitsPerElement)); } std::unique_ptr create( const DenseNonDecomposableStatisticView& statisticMatrix) const override { return std::make_unique>( std::make_unique>( - statisticMatrix, quantizedView_.firstView.numBitsPerElement)); + rngPtr_, statisticMatrix, quantizedView_.firstView.numBitsPerElement)); } }; class StochasticQuantizationFactory final : public IQuantizationFactory { private: + const std::unique_ptr rngFactoryPtr_; + uint32 numBits_; public: - /** - * @param numBits The number of bits to be used for quantized statistics - */ - StochasticQuantizationFactory(uint32 numBits) : numBits_(numBits) {} + StochasticQuantizationFactory(std::unique_ptr rngFactoryPtr, uint32 numBits) + : rngFactoryPtr_(std::move(rngFactoryPtr)), numBits_(numBits) {} std::unique_ptr create( const CContiguousView>& statisticMatrix) const override { return std::make_unique>>>( - std::make_unique>>>(statisticMatrix, - numBits_)); + std::make_unique>>>( + rngFactoryPtr_->create(), statisticMatrix, numBits_)); } std::unique_ptr create(const SparseSetView>& statisticMatrix) const override { return std::make_unique>>>( - std::make_unique>>>(statisticMatrix, - numBits_)); + std::make_unique>>>( + rngFactoryPtr_->create(), statisticMatrix, numBits_)); } std::unique_ptr create( const DenseNonDecomposableStatisticView& statisticMatrix) const override { return std::make_unique>( - std::make_unique>(statisticMatrix, - numBits_)); + std::make_unique>( + rngFactoryPtr_->create(), statisticMatrix, numBits_)); } }; - StochasticQuantizationConfig::StochasticQuantizationConfig() : numBits_(4) {} + StochasticQuantizationConfig::StochasticQuantizationConfig(ReadableProperty rngConfig) + : rngConfig_(rngConfig), numBits_(4) {} uint32 StochasticQuantizationConfig::getNumBits() const { return numBits_; @@ -121,7 +124,7 @@ namespace boosting { } std::unique_ptr StochasticQuantizationConfig::createQuantizationFactory() const { - return std::make_unique(numBits_); + return std::make_unique(rngConfig_.get().createRNGFactory(), numBits_); } }