Skip to content

Commit

Permalink
Merge pull request QMCPACK#4496 from quantumsteve/uptr_sample_stack
Browse files Browse the repository at this point in the history
Remove raw owning pointers in SampleStack
  • Loading branch information
ye-luo authored Mar 3, 2023
2 parents 7a3f3da + 3f9b202 commit 4148d9f
Show file tree
Hide file tree
Showing 10 changed files with 30 additions and 82 deletions.
2 changes: 1 addition & 1 deletion src/Particle/MCWalkerConfiguration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ MCWalkerConfiguration::MCWalkerConfiguration(const MCWalkerConfiguration& mcw)
samples.clearEnsemble();
samples.setMaxSamples(mcw.getMaxSamples());
setWalkerOffsets(mcw.getWalkerOffsets());
Properties = mcw.Properties;
Properties = mcw.Properties;
}

MCWalkerConfiguration::~MCWalkerConfiguration() = default;
Expand Down
47 changes: 7 additions & 40 deletions src/Particle/SampleStack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,82 +11,49 @@


#include "SampleStack.h"
#include "Particle/MCSample.h"
#include "Utilities/IteratorUtility.h"

namespace qmcplusplus
{
SampleStack::SampleStack() : total_num_(0), max_samples_(10), current_sample_count_(0) {}

/** allocate the SampleStack
* @param n number of samples per rank
* @param num_ranks number of ranks. Used to set global number of samples.
*/
void SampleStack::setMaxSamples(int n, int num_ranks)
void SampleStack::setMaxSamples(size_t n, size_t num_ranks)
{
max_samples_ = n;
global_num_samples_ = n * num_ranks;
//do not add anything
if (n == 0)
return;
sample_vector_.reserve(n);
int nadd = n - sample_vector_.size();
while (nadd > 0)
{
sample_vector_.push_back(new MCSample(total_num_));
--nadd;
}
current_sample_count_ = std::min(current_sample_count_, max_samples_);
sample_vector_.resize(n, MCSample(0));
}

MCSample& SampleStack::getSample(unsigned int i) const { return *sample_vector_[i]; }

void SampleStack::saveEnsemble(std::vector<MCSample>& walker_list)
{
//safety check
if (max_samples_ == 0)
return;
auto first = walker_list.begin();
auto last = walker_list.end();
while ((first != last) && (current_sample_count_ < max_samples_))
{
*sample_vector_[current_sample_count_] = *first;
++first;
++current_sample_count_;
}
}
const MCSample& SampleStack::getSample(size_t i) const { return sample_vector_[i]; }

void SampleStack::appendSample(MCSample&& sample)
{
// Ignore samples in excess of the expected number of samples
if (current_sample_count_ < max_samples_)
{
*sample_vector_[current_sample_count_] = std::move(sample);
sample_vector_[current_sample_count_] = std::move(sample);
current_sample_count_++;
}
}


/** load a single sample from SampleStack
*/
void SampleStack::loadSample(ParticleSet& pset, size_t iw) const
{
pset.R = sample_vector_[iw]->R;
pset.spins = sample_vector_[iw]->spins;
pset.R = sample_vector_[iw].R;
pset.spins = sample_vector_[iw].spins;
}

void SampleStack::clearEnsemble()
{
//delete_iter(SampleStack.begin(),SampleStack.end());
for (int i = 0; i < sample_vector_.size(); ++i)
if (sample_vector_[i])
delete sample_vector_[i];
sample_vector_.clear();
max_samples_ = 0;
current_sample_count_ = 0;
}

SampleStack::~SampleStack() { clearEnsemble(); }

void SampleStack::resetSampleCount() { current_sample_count_ = 0; }


Expand Down
30 changes: 10 additions & 20 deletions src/Particle/SampleStack.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,36 +20,29 @@

#include <vector>
#include "Particle/ParticleSet.h"
#include "Particle/MCSample.h"
#include "Particle/Walker.h"
#include "Particle/WalkerConfigurations.h"

namespace qmcplusplus
{
struct MCSample;

class SampleStack
{
public:
using PropertySetType = QMCTraits::PropertySetType;

SampleStack();

void setTotalNum(int total_num) { total_num_ = total_num; }

int getMaxSamples() const { return max_samples_; }
size_t getMaxSamples() const { return max_samples_; }

bool empty() const { return sample_vector_.empty(); }

MCSample& getSample(unsigned int i) const;
const MCSample& getSample(size_t i) const;

//@{save/load/clear function for optimization
inline int getNumSamples() const { return current_sample_count_; }
inline size_t getNumSamples() const { return current_sample_count_; }
///set the number of max samples per rank.
void setMaxSamples(int n, int number_of_ranks = 1);
void setMaxSamples(size_t n, size_t number_of_ranks = 1);
/// Global number of samples is number of samples per rank * number of ranks
uint64_t getGlobalNumSamples() const { return global_num_samples_; }
///save the position of current walkers
void saveEnsemble(std::vector<MCSample>& walker_list);
size_t getGlobalNumSamples() const { return global_num_samples_; }
/// load a single sample from SampleStack
void loadSample(ParticleSet& pset, size_t iw) const;

Expand All @@ -61,15 +54,12 @@ class SampleStack
/// Set the sample count to zero but preserve the storage
void resetSampleCount();

~SampleStack();

private:
int total_num_;
int max_samples_;
int current_sample_count_;
uint64_t global_num_samples_;
size_t max_samples_{10};
size_t current_sample_count_{0};
size_t global_num_samples_{max_samples_};

std::vector<MCSample*> sample_vector_;
std::vector<MCSample> sample_vector_;
};


Expand Down
1 change: 0 additions & 1 deletion src/Particle/tests/test_sample_stack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ TEST_CASE("SampleStack", "[particle]")
SampleStack samples;

const int total_num = 2; // number of particles
samples.setTotalNum(total_num);

// reserve storage
int nranks = 2;
Expand Down
8 changes: 4 additions & 4 deletions src/QMCDrivers/LMYEngineInterface/LMYE_QMCCostFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@

namespace qmcplusplus
{
int QMCCostFunction::total_samples()
size_t QMCCostFunction::total_samples()
{
// for the unfamiliar, the [] starts a lambda function
return std::accumulate(wClones.begin(), wClones.begin() + NumThreads, 0,
[](int x, const auto& p) { return x + p->numSamples(); });
return std::accumulate(wClones.begin(), wClones.begin() + NumThreads, size_t{0},
[](size_t x, const auto& p) { return x + p->numSamples(); });
}

///////////////////////////////////////////////////////////////////////////////////////////////////
Expand All @@ -35,7 +35,7 @@ int QMCCostFunction::total_samples()
QMCCostFunction::Return_rt QMCCostFunction::LMYEngineCost_detail(cqmc::engine::LMYEngine<ValueType>* EngineObj)
{
// get total number of samples
const int m = this->total_samples();
const size_t m = this->total_samples();

// reset Engine object
EngineObj->reset();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@

namespace qmcplusplus
{
int QMCCostFunctionBatched::total_samples()
{
return samples_.getGlobalNumSamples();
}
size_t QMCCostFunctionBatched::total_samples() { return samples_.getGlobalNumSamples(); }

///////////////////////////////////////////////////////////////////////////////////////////////////
/// \brief Computes the cost function using the LMYEngine for interfacing with batched driver
Expand All @@ -32,7 +29,7 @@ QMCCostFunctionBatched::Return_rt QMCCostFunctionBatched::LMYEngineCost_detail(
cqmc::engine::LMYEngine<Return_t>* EngineObj)
{
// get total number of samples
const int m = this->total_samples();
const size_t m = this->total_samples();
// reset Engine object
EngineObj->reset();

Expand Down
2 changes: 1 addition & 1 deletion src/QMCDrivers/WFOpt/QMCCostFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class QMCCostFunction : public QMCCostFunctionBase, public CloneManager
EffectiveWeight correlatedSampling(bool needGrad = true) override;

#ifdef HAVE_LMY_ENGINE
int total_samples();
size_t total_samples();
Return_rt LMYEngineCost_detail(cqmc::engine::LMYEngine<Return_t>* EngineObj) override;
#endif

Expand Down
2 changes: 1 addition & 1 deletion src/QMCDrivers/WFOpt/QMCCostFunctionBatched.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class QMCCostFunctionBatched : public QMCCostFunctionBase, public QMCTraits


#ifdef HAVE_LMY_ENGINE
int total_samples();
size_t total_samples();
Return_rt LMYEngineCost_detail(cqmc::engine::LMYEngine<Return_t>* EngineObj) override;
#endif

Expand Down
1 change: 0 additions & 1 deletion src/QMCDrivers/tests/QMCDriverNewTestWrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ class QMCDriverNewTestWrapper : public QMCDriverNew
QMCDriverInput&& input,
WalkerConfigurations& wc,
MCPopulation&& population,
SampleStack samples,
Communicate* comm)
: QMCDriverNew(test_project,
std::move(input),
Expand Down
12 changes: 4 additions & 8 deletions src/QMCDrivers/tests/test_QMCDriverNew.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,11 @@ TEST_CASE("QMCDriverNew tiny case", "[drivers]")
auto wavefunction_pool = MinimalWaveFunctionPool::make_diamondC_1x1x1(comm, particle_pool);

auto hamiltonian_pool = MinimalHamiltonianPool::make_hamWithEE(comm, particle_pool, wavefunction_pool);
SampleStack samples;
WalkerConfigurations walker_confs;
QMCDriverNewTestWrapper qmcdriver(test_project, std::move(qmcdriver_input), walker_confs,
MCPopulation(comm->size(), comm->rank(), particle_pool.getParticleSet("e"),
wavefunction_pool.getPrimary(), hamiltonian_pool.getPrimary()),
samples, comm);
comm);

// setStatus must be called before process
std::string root_name{"Test"};
Expand Down Expand Up @@ -97,12 +96,11 @@ TEST_CASE("QMCDriverNew more crowds than threads", "[drivers]")

ProjectData test_project("", ProjectData::DriverVersion::BATCH);
QMCDriverInput qmcdriver_copy(qmcdriver_input);
SampleStack samples;
WalkerConfigurations walker_confs;
QMCDriverNewTestWrapper qmc_batched(test_project, std::move(qmcdriver_copy), walker_confs,
MCPopulation(comm->size(), comm->rank(), particle_pool.getParticleSet("e"),
wavefunction_pool.getPrimary(), hamiltonian_pool.getPrimary()),
samples, comm);
comm);
QMCDriverNewTestWrapper::TestNumCrowdsVsNumThreads<ParallelExecutor<>> testNumCrowds;
testNumCrowds(9);
testNumCrowds(8);
Expand Down Expand Up @@ -136,12 +134,11 @@ TEST_CASE("QMCDriverNew walker counts", "[drivers]")

ProjectData test_project("", ProjectData::DriverVersion::BATCH);
QMCDriverInput qmcdriver_copy(qmcdriver_input);
SampleStack samples;
WalkerConfigurations walker_confs;
QMCDriverNewTestWrapper qmc_batched(test_project, std::move(qmcdriver_copy), walker_confs,
MCPopulation(comm->size(), comm->rank(), particle_pool.getParticleSet("e"),
wavefunction_pool.getPrimary(), hamiltonian_pool.getPrimary()),
samples, comm);
comm);

qmc_batched.testAdjustGlobalWalkerCount();
}
Expand All @@ -165,12 +162,11 @@ TEST_CASE("QMCDriverNew test driver operations", "[drivers]")
auto wavefunction_pool = MinimalWaveFunctionPool::make_diamondC_1x1x1(comm, particle_pool);

auto hamiltonian_pool = MinimalHamiltonianPool::make_hamWithEE(comm, particle_pool, wavefunction_pool);
SampleStack samples;
WalkerConfigurations walker_confs;
QMCDriverNewTestWrapper qmcdriver(test_project, std::move(qmcdriver_input), walker_confs,
MCPopulation(comm->size(), comm->rank(), particle_pool.getParticleSet("e"),
wavefunction_pool.getPrimary(), hamiltonian_pool.getPrimary()),
samples, comm);
comm);


auto tau = 1.0;
Expand Down

0 comments on commit 4148d9f

Please sign in to comment.