-
Notifications
You must be signed in to change notification settings - Fork 141
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4492 from ye-luo/restrict-eref-memory
Restrict Eref update history
- Loading branch information
Showing
14 changed files
with
357 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
////////////////////////////////////////////////////////////////////////////////////// | ||
// This file is distributed under the University of Illinois/NCSA Open Source License. | ||
// See LICENSE file in top directory for details. | ||
// | ||
// Copyright (c) 2023 QMCPACK developers. | ||
// | ||
// File developed by: Ye Luo, [email protected], Argonne National Laboratory | ||
// | ||
// File created by: Ye Luo, [email protected], Argonne National Laboratory | ||
////////////////////////////////////////////////////////////////////////////////////// | ||
// -*- C++ -*- | ||
#ifndef QMCPLUSPLUS_SIZELIMITEDDATAQUEUE_H | ||
#define QMCPLUSPLUS_SIZELIMITEDDATAQUEUE_H | ||
|
||
#include <deque> | ||
#include <array> | ||
#include <cassert> | ||
|
||
namespace qmcplusplus | ||
{ | ||
|
||
/** collect data with a history limit. | ||
* data stored in std::deque<std::array<T, NUM_FIELDS>> | ||
*/ | ||
template<typename T, size_t NUM_FIELDS> | ||
class SizeLimitedDataQueue | ||
{ | ||
public: | ||
struct HistoryElement | ||
{ | ||
T weight; | ||
std::array<T, NUM_FIELDS> properties; | ||
}; | ||
|
||
using value_type = HistoryElement; | ||
|
||
SizeLimitedDataQueue(size_t size_limit) : size_limit_(size_limit) {} | ||
|
||
/// add a new record | ||
void push(const value_type& val) | ||
{ | ||
if (data.size() == size_limit_) | ||
data.pop_front(); | ||
assert(data.size() < size_limit_); | ||
data.push_back(val); | ||
} | ||
|
||
/// add a new record | ||
void push(value_type&& val) | ||
{ | ||
if (data.size() == size_limit_) | ||
data.pop_front(); | ||
assert(data.size() < size_limit_); | ||
data.push_back(val); | ||
} | ||
|
||
/// return weighted average | ||
auto weighted_avg() const | ||
{ | ||
std::array<T, NUM_FIELDS> avg; | ||
std::fill(avg.begin(), avg.end(), T(0)); | ||
T weight_sum = 0; | ||
for (auto& element : data) | ||
{ | ||
weight_sum += element.weight; | ||
for (size_t i = 0; i < NUM_FIELDS; i++) | ||
avg[i] += element.properties[i] * element.weight; | ||
} | ||
for (size_t i = 0; i < NUM_FIELDS; i++) | ||
avg[i] /= weight_sum; | ||
return avg; | ||
} | ||
|
||
/// return the number of records | ||
auto size() const { return data.size(); } | ||
|
||
private: | ||
std::deque<value_type> data; | ||
const size_t size_limit_; | ||
}; | ||
|
||
} // namespace qmcplusplus | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
////////////////////////////////////////////////////////////////////////////////////// | ||
// This file is distributed under the University of Illinois/NCSA Open Source License. | ||
// See LICENSE file in top directory for details. | ||
// | ||
// Copyright (c) 2023 QMCPACK developers. | ||
// | ||
// File developed by: Ye Luo, [email protected], Argonne National Laboratory | ||
// | ||
// File created by: Ye Luo, [email protected], Argonne National Laboratory | ||
////////////////////////////////////////////////////////////////////////////////////// | ||
|
||
#include "catch.hpp" | ||
#include "SizeLimitedDataQueue.hpp" | ||
|
||
namespace qmcplusplus | ||
{ | ||
|
||
TEST_CASE("SizeLimitedDataQueue", "[estimators]") | ||
{ | ||
SizeLimitedDataQueue<double, 1> weight_and_energy(3); | ||
CHECK(weight_and_energy.size() == 0); | ||
{ | ||
weight_and_energy.push({1.0, {2.0}}); | ||
CHECK(weight_and_energy.size() == 1); | ||
auto avg = weight_and_energy.weighted_avg(); | ||
CHECK(Approx(avg[0]) == 2.0); | ||
} | ||
{ | ||
weight_and_energy.push({3.0, {1.0}}); | ||
CHECK(weight_and_energy.size() == 2); | ||
auto avg = weight_and_energy.weighted_avg(); | ||
CHECK(Approx(avg[0]) == 1.25); | ||
} | ||
{ | ||
SizeLimitedDataQueue<double, 1>::HistoryElement temp{0.5, {3.0}}; | ||
weight_and_energy.push(std::move(temp)); | ||
CHECK(weight_and_energy.size() == 3); | ||
auto avg = weight_and_energy.weighted_avg(); | ||
CHECK(Approx(avg[0]) == 1.444444444); | ||
} | ||
{ | ||
weight_and_energy.push({0.5, {3.0}}); | ||
CHECK(weight_and_energy.size() == 3); | ||
auto avg = weight_and_energy.weighted_avg(); | ||
CHECK(Approx(avg[0]) == 1.5); | ||
} | ||
} | ||
|
||
} // namespace qmcplusplus |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
////////////////////////////////////////////////////////////////////////////////////// | ||
// This file is distributed under the University of Illinois/NCSA Open Source License. | ||
// See LICENSE file in top directory for details. | ||
// | ||
// Copyright (c) 2023 QMCPACK developers. | ||
// | ||
// File developed by: Ye Luo, [email protected], Argonne National Laboratory | ||
// | ||
// File created by: Ye Luo, [email protected], Argonne National Laboratory | ||
////////////////////////////////////////////////////////////////////////////////////// | ||
|
||
#include <cassert> | ||
#include "DMCRefEnergy.h" | ||
|
||
namespace qmcplusplus | ||
{ | ||
|
||
using FullPrecReal = DMCRefEnergy::FullPrecReal; | ||
|
||
DMCRefEnergy::DMCRefEnergy(DMCRefEnergyScheme scheme, size_t history_limit) | ||
: scheme_(scheme), energy_and_variance_(history_limit) | ||
{} | ||
|
||
std::tuple<FullPrecReal, FullPrecReal> DMCRefEnergy::getEnergyVariance() const | ||
{ | ||
if (scheme_ == DMCRefEnergyScheme::LIMITED_HISTORY) | ||
{ | ||
auto avg = energy_and_variance_.weighted_avg(); | ||
return {avg[ENERGY], avg[VARIANCE]}; | ||
} | ||
else | ||
return {energy_hist_.mean(), variance_hist_.mean()}; | ||
} | ||
|
||
void DMCRefEnergy::pushWeightEnergyVariance(FullPrecReal weight, FullPrecReal ene, FullPrecReal var) | ||
{ | ||
if (scheme_ == DMCRefEnergyScheme::LIMITED_HISTORY) | ||
energy_and_variance_.push({weight, {ene, var}}); | ||
else | ||
{ | ||
energy_hist_(ene); | ||
variance_hist_(var); | ||
} | ||
} | ||
|
||
size_t DMCRefEnergy::count() const | ||
{ | ||
if (scheme_ == DMCRefEnergyScheme::LIMITED_HISTORY) | ||
return energy_and_variance_.size(); | ||
else | ||
{ | ||
assert(energy_hist_.count() == variance_hist_.count()); | ||
return energy_hist_.count(); | ||
} | ||
} | ||
|
||
} // namespace qmcplusplus |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
////////////////////////////////////////////////////////////////////////////////////// | ||
// This file is distributed under the University of Illinois/NCSA Open Source License. | ||
// See LICENSE file in top directory for details. | ||
// | ||
// Copyright (c) 2023 QMCPACK developers. | ||
// | ||
// File developed by: Ye Luo, [email protected], Argonne National Laboratory | ||
// | ||
// File created by: Ye Luo, [email protected], Argonne National Laboratory | ||
////////////////////////////////////////////////////////////////////////////////////// | ||
// -*- C++ -*- | ||
#ifndef QMCPLUSPLUS_DMCREFENERGY_H | ||
#define QMCPLUSPLUS_DMCREFENERGY_H | ||
|
||
#include <tuple> | ||
#include <Configuration.h> | ||
#include <Estimators/SizeLimitedDataQueue.hpp> | ||
#include <Estimators/accumulators.h> | ||
#include "DMCRefEnergyScheme.h" | ||
|
||
namespace qmcplusplus | ||
{ | ||
/** Handle updating Eref used for calculating the trial energy. | ||
*/ | ||
class DMCRefEnergy | ||
{ | ||
public: | ||
using FullPrecReal = QMCTraits::FullPrecRealType; | ||
|
||
enum DataLayout | ||
{ | ||
ENERGY = 0, | ||
VARIANCE, | ||
DATA_SIZE | ||
}; | ||
|
||
private: | ||
/// scheme | ||
DMCRefEnergyScheme scheme_; | ||
|
||
// legacy scheme data | ||
///a simple accumulator for energy | ||
accumulator_set<FullPrecReal> energy_hist_; | ||
///a simple accumulator for variance | ||
accumulator_set<FullPrecReal> variance_hist_; | ||
|
||
// limited memory scheme data | ||
SizeLimitedDataQueue<FullPrecReal, DataLayout::DATA_SIZE> energy_and_variance_; | ||
|
||
public: | ||
DMCRefEnergy(DMCRefEnergyScheme scheme, size_t history_limit); | ||
|
||
/// return energy and variance | ||
std::tuple<FullPrecReal, FullPrecReal> getEnergyVariance() const; | ||
|
||
/// record weight, energy and variance. | ||
void pushWeightEnergyVariance(FullPrecReal weight, FullPrecReal ene, FullPrecReal var); | ||
|
||
/// return record count. | ||
size_t count() const; | ||
}; | ||
|
||
} // namespace qmcplusplus | ||
#endif |
Oops, something went wrong.