Skip to content

Commit

Permalink
Merge pull request #1 from camelto2/move_switch_to_TWF
Browse files Browse the repository at this point in the history
move switch from TWFdispatcher into TrialWaveFunction
  • Loading branch information
camelto2 authored Feb 16, 2022
2 parents bda3f0b + 95ae489 commit 41236e9
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 99 deletions.
81 changes: 12 additions & 69 deletions src/QMCWaveFunctions/TWFdispatcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,46 +77,19 @@ void TWFdispatcher::flex_evalGrad(const RefVectorWithLeader<TrialWaveFunction>&
const RefVectorWithLeader<ParticleSet>& p_list,
int iat,
TWFGrads<CT>& grads) const
{
if constexpr (CT == CoordsType::POS_SPIN)
flex_evalGradWithSpin(wf_list, p_list, iat, grads.grads_positions, grads.grads_spins);
else
flex_evalGrad(wf_list, p_list, iat, grads.grads_positions);
}

void TWFdispatcher::flex_evalGrad(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
const RefVectorWithLeader<ParticleSet>& p_list,
int iat,
std::vector<GradType>& grad_now) const
{
assert(wf_list.size() == p_list.size());
if (use_batch_)
TrialWaveFunction::mw_evalGrad(wf_list, p_list, iat, grad_now);
TrialWaveFunction::mw_evalGrad(wf_list, p_list, iat, grads);
else
{
const int num_wf = wf_list.size();
grad_now.resize(num_wf);
grads.resize(num_wf);
for (size_t iw = 0; iw < num_wf; iw++)
grad_now[iw] = wf_list[iw].evalGrad(p_list[iw], iat);
}
}

void TWFdispatcher::flex_evalGradWithSpin(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
const RefVectorWithLeader<ParticleSet>& p_list,
int iat,
std::vector<GradType>& grad_now,
std::vector<Complex>& spingrad_now) const
{
assert(wf_list.size() == p_list.size());
if (use_batch_)
TrialWaveFunction::mw_evalGradWithSpin(wf_list, p_list, iat, grad_now, spingrad_now);
else
{
const int num_wf = wf_list.size();
grad_now.resize(num_wf);
spingrad_now.resize(num_wf);
for (size_t iw = 0; iw < num_wf; iw++)
grad_now[iw] = wf_list[iw].evalGradWithSpin(p_list[iw], iat, spingrad_now[iw]);
if constexpr (CT == CoordsType::POS_SPIN)
grads.grads_positions[iw] = wf_list[iw].evalGradWithSpin(p_list[iw], iat, grads.grads_spins[iw]);
else
grads.grads_positions[iw] = wf_list[iw].evalGrad(p_list[iw], iat);
}
}

Expand All @@ -126,50 +99,20 @@ void TWFdispatcher::flex_calcRatioGrad(const RefVectorWithLeader<TrialWaveFuncti
int iat,
std::vector<PsiValueType>& ratios,
TWFGrads<CT>& grads) const
{
if constexpr (CT == CoordsType::POS_SPIN)
flex_calcRatioGradWithSpin(wf_list, p_list, iat, ratios, grads.grads_positions, grads.grads_spins);
else
flex_calcRatioGrad(wf_list, p_list, iat, ratios, grads.grads_positions);
}

void TWFdispatcher::flex_calcRatioGrad(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
const RefVectorWithLeader<ParticleSet>& p_list,
int iat,
std::vector<PsiValueType>& ratios,
std::vector<GradType>& grad_new) const
{
assert(wf_list.size() == p_list.size());
if (use_batch_)
TrialWaveFunction::mw_calcRatioGrad(wf_list, p_list, iat, ratios, grad_new);
TrialWaveFunction::mw_calcRatioGrad(wf_list, p_list, iat, ratios, grads);
else
{
const int num_wf = wf_list.size();
ratios.resize(num_wf);
grad_new.resize(num_wf);
grads.resize(num_wf);
for (size_t iw = 0; iw < num_wf; iw++)
ratios[iw] = wf_list[iw].calcRatioGrad(p_list[iw], iat, grad_new[iw]);
}
}

void TWFdispatcher::flex_calcRatioGradWithSpin(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
const RefVectorWithLeader<ParticleSet>& p_list,
int iat,
std::vector<PsiValueType>& ratios,
std::vector<GradType>& grad_new,
std::vector<Complex>& spingrad_new) const
{
assert(wf_list.size() == p_list.size());
if (use_batch_)
TrialWaveFunction::mw_calcRatioGradWithSpin(wf_list, p_list, iat, ratios, grad_new, spingrad_new);
else
{
const int num_wf = wf_list.size();
ratios.resize(num_wf);
grad_new.resize(num_wf);
spingrad_new.resize(num_wf);
for (size_t iw = 0; iw < num_wf; iw++)
ratios[iw] = wf_list[iw].calcRatioGradWithSpin(p_list[iw], iat, grad_new[iw], spingrad_new[iw]);
if constexpr (CT == CoordsType::POS_SPIN)
ratios[iw] = wf_list[iw].calcRatioGradWithSpin(p_list[iw], iat, grads.grads_positions[iw], grads.grads_spins[iw]);
else
ratios[iw] = wf_list[iw].calcRatioGrad(p_list[iw], iat, grads.grads_positions[iw]);
}
}

Expand Down
24 changes: 0 additions & 24 deletions src/QMCWaveFunctions/TWFdispatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,37 +56,13 @@ class TWFdispatcher
int iat,
TWFGrads<CT>& grads) const;

void flex_evalGrad(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
const RefVectorWithLeader<ParticleSet>& p_list,
int iat,
std::vector<GradType>& grad_now) const;

void flex_evalGradWithSpin(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
const RefVectorWithLeader<ParticleSet>& p_list,
int iat,
std::vector<GradType>& grad_now,
std::vector<Complex>& spingrad_now) const;

template<CoordsType CT>
void flex_calcRatioGrad(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
const RefVectorWithLeader<ParticleSet>& p_list,
int iat,
std::vector<PsiValueType>& ratios,
TWFGrads<CT>& grads) const;

void flex_calcRatioGrad(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
const RefVectorWithLeader<ParticleSet>& p_list,
int iat,
std::vector<PsiValueType>& ratios,
std::vector<GradType>& grad_new) const;

void flex_calcRatioGradWithSpin(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
const RefVectorWithLeader<ParticleSet>& p_list,
int iat,
std::vector<PsiValueType>& ratios,
std::vector<GradType>& grad_new,
std::vector<Complex>& spingrad_new) const;

void flex_accept_rejectMove(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
const RefVectorWithLeader<ParticleSet>& p_list,
int iat,
Expand Down
60 changes: 54 additions & 6 deletions src/QMCWaveFunctions/TrialWaveFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,12 @@ void TrialWaveFunction::mw_evaluateLog(const RefVectorWithLeader<TrialWaveFuncti
const int num_particles = p_leader.getTotalNum();
auto initGandL = [num_particles, czero](TrialWaveFunction& twf, ParticleSet::ParticleGradient& grad,
ParticleSet::ParticleLaplacian& lapl) {
grad.resize(num_particles);
lapl.resize(num_particles);
grad = czero;
lapl = czero;
twf.log_real_ = czero;
twf.PhaseValue = czero;
grad.resize(num_particles);
lapl.resize(num_particles);
grad = czero;
lapl = czero;
twf.log_real_ = czero;
twf.PhaseValue = czero;
};
for (int iw = 0; iw < wf_list.size(); iw++)
initGandL(wf_list[iw], g_list[iw], l_list[iw]);
Expand Down Expand Up @@ -527,6 +527,18 @@ TrialWaveFunction::GradType TrialWaveFunction::evalGradWithSpin(ParticleSet& P,
return grad_iat;
}

template<CoordsType CT>
void TrialWaveFunction::mw_evalGrad(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
const RefVectorWithLeader<ParticleSet>& p_list,
int iat,
TWFGrads<CT>& grads)
{
if constexpr (CT == CoordsType::POS_SPIN)
mw_evalGradWithSpin(wf_list, p_list, iat, grads.grads_positions, grads.grads_spins);
else
mw_evalGrad(wf_list, p_list, iat, grads.grads_positions);
}

void TrialWaveFunction::mw_evalGrad(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
const RefVectorWithLeader<ParticleSet>& p_list,
int iat,
Expand Down Expand Up @@ -668,6 +680,19 @@ TrialWaveFunction::ValueType TrialWaveFunction::calcRatioGradWithSpin(ParticleSe
return static_cast<ValueType>(r);
}

template<CoordsType CT>
void TrialWaveFunction::mw_calcRatioGrad(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
const RefVectorWithLeader<ParticleSet>& p_list,
int iat,
std::vector<PsiValueType>& ratios,
TWFGrads<CT>& grads)
{
if constexpr (CT == CoordsType::POS_SPIN)
mw_calcRatioGradWithSpin(wf_list, p_list, iat, ratios, grads.grads_positions, grads.grads_spins);
else
mw_calcRatioGrad(wf_list, p_list, iat, ratios, grads.grads_positions);
}

void TrialWaveFunction::mw_calcRatioGrad(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
const RefVectorWithLeader<ParticleSet>& p_list,
int iat,
Expand Down Expand Up @@ -1332,4 +1357,27 @@ void TrialWaveFunction::initializeTWFFastDerivWrapper(const ParticleSet& P, TWFF
}
}

//explicit instantiations
template void TrialWaveFunction::mw_evalGrad<CoordsType::POS>(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
const RefVectorWithLeader<ParticleSet>& p_list,
int iat,
TWFGrads<CoordsType::POS>& grads);
template void TrialWaveFunction::mw_evalGrad<CoordsType::POS_SPIN>(
const RefVectorWithLeader<TrialWaveFunction>& wf_list,
const RefVectorWithLeader<ParticleSet>& p_list,
int iat,
TWFGrads<CoordsType::POS_SPIN>& grads);
template void TrialWaveFunction::mw_calcRatioGrad<CoordsType::POS>(
const RefVectorWithLeader<TrialWaveFunction>& wf_list,
const RefVectorWithLeader<ParticleSet>& p_list,
int iat,
std::vector<PsiValueType>& ratios,
TWFGrads<CoordsType::POS>& grads);
template void TrialWaveFunction::mw_calcRatioGrad<CoordsType::POS_SPIN>(
const RefVectorWithLeader<TrialWaveFunction>& wf_list,
const RefVectorWithLeader<ParticleSet>& p_list,
int iat,
std::vector<PsiValueType>& ratios,
TWFGrads<CoordsType::POS_SPIN>& grads);

} // namespace qmcplusplus
25 changes: 25 additions & 0 deletions src/QMCWaveFunctions/TrialWaveFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "type_traits/template_types.hpp"
#include "Containers/MinimalContainers/RecordArray.hpp"
#include "QMCWaveFunctions/TWFFastDerivWrapper.h"
#include "TWFGrads.hpp"
#ifdef QMC_CUDA
#include "type_traits/CUDATypes.h"
#endif
Expand Down Expand Up @@ -348,6 +349,18 @@ class TrialWaveFunction
*/
ValueType calcRatioGradWithSpin(ParticleSet& P, int iat, GradType& grad_iat, ComplexType& spingrad_iat);

/** batched version of ratioGrad
*
* all vector sizes must match
* implements switch between normal and WithSpin version
*/
template<CoordsType CT>
static void mw_calcRatioGrad(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
const RefVectorWithLeader<ParticleSet>& p_list,
int iat,
std::vector<PsiValueType>& ratios,
TWFGrads<CT>& grads);

/** batched version of ratioGrad
*
* all vector sizes must match
Expand Down Expand Up @@ -396,6 +409,18 @@ class TrialWaveFunction
*/
GradType evalGradWithSpin(ParticleSet& P, int iat, ComplexType& spingrad);

/** batched version of evalGrad
*
* This is static because it should have no direct access
* to any TWF.
* implements switch between normal and WithSpin version
*/
template<CoordsType CT>
static void mw_evalGrad(const RefVectorWithLeader<TrialWaveFunction>& wf_list,
const RefVectorWithLeader<ParticleSet>& p_list,
int iat,
TWFGrads<CT>& grads);

/** batched version of evalGrad
*
* This is static because it should have no direct access
Expand Down

0 comments on commit 41236e9

Please sign in to comment.