Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batched driver move abstraction #3762

Merged
merged 40 commits into from
Feb 17, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
a6d3e61
add abstraction layer in batched qmc
camelto2 Jan 21, 2022
b1c36c0
add spin moves to abstraction
camelto2 Jan 24, 2022
c7a11ff
add selection in VMCBatched
camelto2 Jan 24, 2022
5b8cd85
add advanceWalkers<POSITIONS_SPINS> to vmcbatched
camelto2 Jan 24, 2022
b71e250
add DMC move abstraction
camelto2 Jan 24, 2022
39e809a
add advanceWalkers<POSITIONS_SPINS>
camelto2 Jan 24, 2022
40dfce7
Merge remote-tracking branch 'upstream/develop' into batched_driver_m…
camelto2 Jan 25, 2022
71ef608
add comments to MoveAbtraction class
camelto2 Jan 25, 2022
f46a9bf
add docs for spin mass to batched drivers
camelto2 Jan 25, 2022
5169590
remove unused variable
camelto2 Jan 25, 2022
1f28c8b
move CoordsToMove enum to MoveAbstraction
camelto2 Jan 25, 2022
cd29a55
remove TWFdispatcher from MoveAbstraction, add elecs
camelto2 Jan 25, 2022
cdbe395
separate mw_ spatial and spin moves in ParticleSet to individual APIs
camelto2 Jan 26, 2022
9132b29
change move to mover in drivers
camelto2 Jan 26, 2022
b65d664
use golden_electrons to get spinor info
camelto2 Jan 26, 2022
7e4e77e
use golden electrons
camelto2 Jan 26, 2022
e77f2fe
Merge branch 'develop' into batched_driver_move_abstraction
camelto2 Feb 2, 2022
34b4b7a
Merge remote-tracking branch 'upstream/develop' into batched_driver_m…
camelto2 Feb 7, 2022
ea7c681
remove extra definition of function
camelto2 Feb 7, 2022
8d0eab8
change MoveAbtraction to template on CoordsType
camelto2 Feb 8, 2022
97a737c
use Taus object
camelto2 Feb 8, 2022
b10e42d
change MoveAbstraction members to use MCCoords
camelto2 Feb 9, 2022
fb1e55f
Merge remote-tracking branch 'upstream/develop' into batched_driver_m…
camelto2 Feb 9, 2022
c98ed9c
Add template version of flex/mw_makeMove.
ye-luo Feb 9, 2022
fe2b054
add template version of flex accept/reject
camelto2 Feb 9, 2022
28dc604
simplify check on CoordsType in MoveAbstraction
camelto2 Feb 9, 2022
c5991a4
add TWFGrads<CT> and template TWFdispatcher flex_ and DriftModifierBase
camelto2 Feb 9, 2022
618b262
remove MoveAbstraction
camelto2 Feb 9, 2022
2f58150
call accept_rejectSpinMove first
camelto2 Feb 9, 2022
877233d
fix failing test
camelto2 Feb 10, 2022
a6a2f15
Fix template.
ye-luo Feb 11, 2022
4e5e965
Add TauParams.hpp
ye-luo Feb 11, 2022
bda3f0b
Simplify ContextForSteps
ye-luo Feb 11, 2022
95ae489
move switch from TWFdispatcher into TrialWaveFunction
camelto2 Feb 15, 2022
41236e9
Merge pull request #1 from camelto2/move_switch_to_TWF
camelto2 Feb 16, 2022
60abdad
Update header info
ye-luo Feb 16, 2022
c2220ae
update spinor tests to use templated APIs in ParticleSet
camelto2 Feb 16, 2022
22521d2
add unit test for TWFGrads
camelto2 Feb 16, 2022
6593e7a
Merge branch 'develop' into batched_driver_move_abstraction
ye-luo Feb 16, 2022
3c7aeb6
add test_TWFGrads to cmake
camelto2 Feb 16, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/Particle/ParticleBase/RandomSeqGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ inline void makeGaussRandomWithEngine(ParticleAttrib<T>& a, RG& rng)
assignGaussRand(&(a[0]), a.size(), rng);
}

template<typename T, class RG>
inline void makeGaussRandomWithEngine(std::vector<T>& a, RG& rng)
{
assignGaussRand(&(a[0]), a.size(), rng);
}

} // namespace qmcplusplus


Expand Down
77 changes: 37 additions & 40 deletions src/QMCDrivers/DMC/DMCBatched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "QMCDrivers/DMC/WalkerControl.h"
#include "QMCDrivers/SFNBranch.h"
#include "MemoryUsage.h"
#include "MoveAbstraction.h"

namespace qmcplusplus
{
Expand Down Expand Up @@ -60,6 +61,7 @@ void DMCBatched::setNonLocalMoveHandler(QMCHamiltonian& golden_hamiltonian)
dmcdriver_input_.get_alpha(), dmcdriver_input_.get_gamma());
}

template<QMCDriverNew::CoordsToMove COORDS>
void DMCBatched::advanceWalkers(const StateForThread& sft,
Crowd& crowd,
DriverTimers& timers,
Expand Down Expand Up @@ -102,14 +104,13 @@ void DMCBatched::advanceWalkers(const StateForThread& sft,

const int num_walkers = crowd.size();

MoveAbstraction<COORDS> move(ps_dispatcher, twf_dispatcher, step_context.get_random_gen(), sft.drift_modifier,
num_walkers, sft.population.get_num_particles());

//This generates an entire steps worth of deltas.
step_context.nextDeltaRs(num_walkers * sft.population.get_num_particles());
auto it_delta_r = step_context.deltaRsBegin();
move.generateDeltas();

std::vector<TrialWaveFunction::GradType> grads_now(num_walkers, TrialWaveFunction::GradType(0.0));
std::vector<TrialWaveFunction::GradType> grads_new(num_walkers, TrialWaveFunction::GradType(0.0));
std::vector<TrialWaveFunction::PsiValueType> ratios(num_walkers, TrialWaveFunction::PsiValueType(0.0));
std::vector<PosType> drifts(num_walkers, 0.0);
std::vector<RealType> log_gf(num_walkers, 0.0);
std::vector<RealType> log_gb(num_walkers, 0.0);
std::vector<RealType> prob(num_walkers, 0.0);
Expand All @@ -130,19 +131,14 @@ void DMCBatched::advanceWalkers(const StateForThread& sft,
ScopedTimer pbyp_local_timer(timers.movepbyp_timer);
for (int ig = 0; ig < step_context.get_num_groups(); ++ig)
{
RealType tauovermass = sft.qmcdrv_input.get_tau() * sft.population.get_ptclgrp_inv_mass()[ig];
RealType oneover2tau = 0.5 / (tauovermass);
RealType sqrttau = std::sqrt(tauovermass);
move.setTauForGroup(sft.qmcdrv_input, sft.population.get_ptclgrp_inv_mass()[ig]);

twf_dispatcher.flex_prepareGroup(walker_twfs, walker_elecs, ig);

int start_index = step_context.getPtclGroupStart(ig);
int end_index = step_context.getPtclGroupEnd(ig);
for (int iat = start_index; iat < end_index; ++iat)
{
auto delta_r_start = it_delta_r + iat * num_walkers;
auto delta_r_end = delta_r_start + num_walkers;

//This is very useful thing to be able to look at in the debugger
#ifndef NDEBUG
std::vector<int> walkers_who_have_been_on_wire(num_walkers, 0);
Expand All @@ -152,19 +148,12 @@ void DMCBatched::advanceWalkers(const StateForThread& sft,
: walkers_who_have_been_on_wire[iw] = 0;
}
#endif
//get the displacement
twf_dispatcher.flex_evalGrad(walker_twfs, walker_elecs, iat, grads_now);
sft.drift_modifier.getDrifts(tauovermass, grads_now, drifts);

std::transform(drifts.begin(), drifts.end(), delta_r_start, drifts.begin(),
[sqrttau](PosType& drift, PosType& delta_r) { return drift + (sqrttau * delta_r); });
move.calcForwardMoveWithDrift(walker_twfs, walker_elecs, iat);

// only DMC does this
// TODO: rr needs a real name
std::vector<RealType> rr(num_walkers, 0.0);
assert(rr.size() == delta_r_end - delta_r_start);
std::transform(delta_r_start, delta_r_end, rr.begin(),
[tauovermass](auto& delta_r) { return tauovermass * dot(delta_r, delta_r); });
move.updaterr(iat, rr);

// in DMC this was done here, changed to match VMCBatched pending factoring to common source
// if (rr > m_r2max)
Expand All @@ -176,9 +165,9 @@ void DMCBatched::advanceWalkers(const StateForThread& sft,
for (int i = 0; i < rr.size(); ++i)
assert(std::isfinite(rr[i]));
#endif
ps_dispatcher.flex_makeMove(walker_elecs, iat, drifts);
move.makeMove(walker_elecs, iat);

twf_dispatcher.flex_calcRatioGrad(walker_twfs, walker_elecs, iat, ratios, grads_new);
move.updateGreensFunctionWithDrift(walker_twfs, walker_elecs, crowd, iat, ratios, log_gf, log_gb);

auto checkPhaseChanged = [&sft](const TrialWaveFunction& twf, int& is_reject) {
if (sft.branch_engine.phaseChanged(twf.getPhaseDiff()))
Expand All @@ -196,19 +185,6 @@ void DMCBatched::advanceWalkers(const StateForThread& sft,
rr_proposed[iw] += rr[iw];
}

std::transform(delta_r_start, delta_r_end, log_gf.begin(), [](auto& delta_r) {
constexpr RealType mhalf(-0.5);
return mhalf * dot(delta_r, delta_r);
});

sft.drift_modifier.getDrifts(tauovermass, grads_new, drifts);

std::transform(crowd.beginElectrons(), crowd.endElectrons(), drifts.begin(), drifts.begin(),
[iat](auto& elecs, auto& drift) { return elecs.get().R[iat] - elecs.get().getActivePos() - drift; });

std::transform(drifts.begin(), drifts.end(), log_gb.begin(),
[oneover2tau](auto& drift) { return -oneover2tau * dot(drift, drift); });

for (int iw = 0; iw < num_walkers; ++iw)
prob[iw] = std::norm(ratios[iw]) * std::exp(log_gb[iw] - log_gf[iw]);

Expand Down Expand Up @@ -324,6 +300,22 @@ void DMCBatched::advanceWalkers(const StateForThread& sft,
}
}

template void DMCBatched::advanceWalkers<QMCDriverNew::POSITIONS>(const StateForThread& sft,
Crowd& crowd,
DriverTimers& timers,
DMCTimers& dmc_timers,
ContextForSteps& step_context,
bool recompute,
bool accumulate_this_step);

template void DMCBatched::advanceWalkers<QMCDriverNew::POSITIONS_SPINS>(const StateForThread& sft,
Crowd& crowd,
DriverTimers& timers,
DMCTimers& dmc_timers,
ContextForSteps& step_context,
bool recompute,
bool accumulate_this_step);

void DMCBatched::runDMCStep(int crowd_id,
const StateForThread& sft,
DriverTimers& timers,
Expand All @@ -344,8 +336,13 @@ void DMCBatched::runDMCStep(int crowd_id,
// Are we entering the the last step of a block to recompute at?
const bool recompute_this_step = (sft.is_recomputing_block && (step + 1) == max_steps);
const bool accumulate_this_step = true;
advanceWalkers(sft, crowd, timers, dmc_timers, *context_for_steps[crowd_id], recompute_this_step,
accumulate_this_step);
const bool spin_move = crowd.get_walker_elecs()[0].get().isSpinor();
if (spin_move)
advanceWalkers<POSITIONS_SPINS>(sft, crowd, timers, dmc_timers, *context_for_steps[crowd_id], recompute_this_step,
accumulate_this_step);
else
advanceWalkers<POSITIONS>(sft, crowd, timers, dmc_timers, *context_for_steps[crowd_id], recompute_this_step,
accumulate_this_step);
}

void DMCBatched::process(xmlNodePtr node)
Expand Down Expand Up @@ -433,9 +430,9 @@ bool DMCBatched::run()
dmc_state.recalculate_properties_period = (qmc_driver_mode_[QMC_UPDATE_MODE])
? qmcdriver_input_.get_recalculate_properties_period()
: (qmcdriver_input_.get_max_blocks() + 1) * qmcdriver_input_.get_max_steps();
dmc_state.is_recomputing_block = qmcdriver_input_.get_blocks_between_recompute()
? (1 + block) % qmcdriver_input_.get_blocks_between_recompute() == 0
: false;
dmc_state.is_recomputing_block = qmcdriver_input_.get_blocks_between_recompute()
? (1 + block) % qmcdriver_input_.get_blocks_between_recompute() == 0
: false;

for (UPtr<Crowd>& crowd : crowds_)
crowd->startBlock(qmcdriver_input_.get_max_steps());
Expand Down
1 change: 1 addition & 0 deletions src/QMCDrivers/DMC/DMCBatched.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ class DMCBatched : public QMCDriverNew
///walker controller for load-balance
std::unique_ptr<WalkerControl> walker_controller_;

template<CoordsToMove COORDS>
static void advanceWalkers(const StateForThread& sft,
Crowd& crowd,
DriverTimers& timers,
Expand Down
4 changes: 4 additions & 0 deletions src/QMCDrivers/GreenFunctionModifiers/DriftModifierBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ class DriftModifierBase

virtual void getDrifts(RealType tau, const std::vector<GradType>& qf, std::vector<PosType>&) const = 0;

virtual void getDrifts(RealType tau,
const std::vector<ComplexType>& qf,
std::vector<ParticleSet::Scalar_t>&) const = 0;

virtual bool parseXML(xmlNodePtr cur) { return true; }

virtual ~DriftModifierBase() {}
Expand Down
16 changes: 13 additions & 3 deletions src/QMCDrivers/GreenFunctionModifiers/DriftModifierUNR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ void DriftModifierUNR::getDrift(RealType tau, const GradType& qf, PosType& drift
// Generally we hope that this would only occur as the result of bad input
// which would hopefully be the result of development time error and
// therefore caught when run in a Debug build.
if( std::isnan(vsq) )
if (std::isnan(vsq))
{
std::ostringstream error_message;
for(int i = 0; i < drift.size(); ++i)
for (int i = 0; i < drift.size(); ++i)
{
if (std::isnan(drift[i]))
{
Expand Down Expand Up @@ -66,7 +66,7 @@ void DriftModifierUNR::getDrift(RealType tau, const ComplexType& qf, ParticleSet
// Generally we hope that this would only occur as the result of bad input
// which would hopefully be the result of development time error and
// therefore caught when run in a Debug build.
if( std::isnan(vsq) )
if (std::isnan(vsq))
{
std::ostringstream error_message;
if (std::isnan(drift))
Expand All @@ -90,6 +90,16 @@ void DriftModifierUNR::getDrifts(RealType tau, const std::vector<GradType>& qf,
}
}

void DriftModifierUNR::getDrifts(RealType tau,
const std::vector<ComplexType>& qf,
std::vector<ParticleSet::Scalar_t>& drift) const
{
for (int i = 0; i < qf.size(); ++i)
{
getDrift(tau, qf[i], drift[i]);
}
}

bool DriftModifierUNR::parseXML(xmlNodePtr cur)
{
ParameterSet m_param;
Expand Down
4 changes: 4 additions & 0 deletions src/QMCDrivers/GreenFunctionModifiers/DriftModifierUNR.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ class DriftModifierUNR : public DriftModifierBase

void getDrift(RealType tau, const GradType& qf, PosType& drift) const final;

void getDrifts(RealType tau,
const std::vector<ComplexType>& qf,
std::vector<ParticleSet::Scalar_t>& drift) const final;

void getDrift(RealType tau, const ComplexType& qf, ParticleSet::Scalar_t& drift) const final;

bool parseXML(xmlNodePtr cur) final;
Expand Down
Loading