diff --git a/src/QMCWaveFunctions/TWFdispatcher.cpp b/src/QMCWaveFunctions/TWFdispatcher.cpp index 89a0414802..79ce0b5907 100644 --- a/src/QMCWaveFunctions/TWFdispatcher.cpp +++ b/src/QMCWaveFunctions/TWFdispatcher.cpp @@ -77,46 +77,19 @@ void TWFdispatcher::flex_evalGrad(const RefVectorWithLeader& const RefVectorWithLeader& p_list, int iat, TWFGrads& grads) const -{ - if constexpr (CT == CoordsType::POS_SPIN) - flex_evalGradWithSpin(wf_list, p_list, iat, grads.grads_positions, grads.grads_spins); - else - flex_evalGrad(wf_list, p_list, iat, grads.grads_positions); -} - -void TWFdispatcher::flex_evalGrad(const RefVectorWithLeader& wf_list, - const RefVectorWithLeader& p_list, - int iat, - std::vector& grad_now) const { assert(wf_list.size() == p_list.size()); if (use_batch_) - TrialWaveFunction::mw_evalGrad(wf_list, p_list, iat, grad_now); + TrialWaveFunction::mw_evalGrad(wf_list, p_list, iat, grads); else { const int num_wf = wf_list.size(); - grad_now.resize(num_wf); + grads.resize(num_wf); for (size_t iw = 0; iw < num_wf; iw++) - grad_now[iw] = wf_list[iw].evalGrad(p_list[iw], iat); - } -} - -void TWFdispatcher::flex_evalGradWithSpin(const RefVectorWithLeader& wf_list, - const RefVectorWithLeader& p_list, - int iat, - std::vector& grad_now, - std::vector& spingrad_now) const -{ - assert(wf_list.size() == p_list.size()); - if (use_batch_) - TrialWaveFunction::mw_evalGradWithSpin(wf_list, p_list, iat, grad_now, spingrad_now); - else - { - const int num_wf = wf_list.size(); - grad_now.resize(num_wf); - spingrad_now.resize(num_wf); - for (size_t iw = 0; iw < num_wf; iw++) - grad_now[iw] = wf_list[iw].evalGradWithSpin(p_list[iw], iat, spingrad_now[iw]); + if constexpr (CT == CoordsType::POS_SPIN) + grads.grads_positions[iw] = wf_list[iw].evalGradWithSpin(p_list[iw], iat, grads.grads_spins[iw]); + else + grads.grads_positions[iw] = wf_list[iw].evalGrad(p_list[iw], iat); } } @@ -126,50 +99,20 @@ void TWFdispatcher::flex_calcRatioGrad(const RefVectorWithLeader& ratios, TWFGrads& grads) const -{ - if constexpr (CT == CoordsType::POS_SPIN) - flex_calcRatioGradWithSpin(wf_list, p_list, iat, ratios, grads.grads_positions, grads.grads_spins); - else - flex_calcRatioGrad(wf_list, p_list, iat, ratios, grads.grads_positions); -} - -void TWFdispatcher::flex_calcRatioGrad(const RefVectorWithLeader& wf_list, - const RefVectorWithLeader& p_list, - int iat, - std::vector& ratios, - std::vector& grad_new) const { assert(wf_list.size() == p_list.size()); if (use_batch_) - TrialWaveFunction::mw_calcRatioGrad(wf_list, p_list, iat, ratios, grad_new); + TrialWaveFunction::mw_calcRatioGrad(wf_list, p_list, iat, ratios, grads); else { const int num_wf = wf_list.size(); ratios.resize(num_wf); - grad_new.resize(num_wf); + grads.resize(num_wf); for (size_t iw = 0; iw < num_wf; iw++) - ratios[iw] = wf_list[iw].calcRatioGrad(p_list[iw], iat, grad_new[iw]); - } -} - -void TWFdispatcher::flex_calcRatioGradWithSpin(const RefVectorWithLeader& wf_list, - const RefVectorWithLeader& p_list, - int iat, - std::vector& ratios, - std::vector& grad_new, - std::vector& spingrad_new) const -{ - assert(wf_list.size() == p_list.size()); - if (use_batch_) - TrialWaveFunction::mw_calcRatioGradWithSpin(wf_list, p_list, iat, ratios, grad_new, spingrad_new); - else - { - const int num_wf = wf_list.size(); - ratios.resize(num_wf); - grad_new.resize(num_wf); - spingrad_new.resize(num_wf); - for (size_t iw = 0; iw < num_wf; iw++) - ratios[iw] = wf_list[iw].calcRatioGradWithSpin(p_list[iw], iat, grad_new[iw], spingrad_new[iw]); + if constexpr (CT == CoordsType::POS_SPIN) + ratios[iw] = wf_list[iw].calcRatioGradWithSpin(p_list[iw], iat, grads.grads_positions[iw], grads.grads_spins[iw]); + else + ratios[iw] = wf_list[iw].calcRatioGrad(p_list[iw], iat, grads.grads_positions[iw]); } } diff --git a/src/QMCWaveFunctions/TWFdispatcher.h b/src/QMCWaveFunctions/TWFdispatcher.h index 961edd2eea..8a44d39be4 100644 --- a/src/QMCWaveFunctions/TWFdispatcher.h +++ b/src/QMCWaveFunctions/TWFdispatcher.h @@ -56,17 +56,6 @@ class TWFdispatcher int iat, TWFGrads& grads) const; - void flex_evalGrad(const RefVectorWithLeader& wf_list, - const RefVectorWithLeader& p_list, - int iat, - std::vector& grad_now) const; - - void flex_evalGradWithSpin(const RefVectorWithLeader& wf_list, - const RefVectorWithLeader& p_list, - int iat, - std::vector& grad_now, - std::vector& spingrad_now) const; - template void flex_calcRatioGrad(const RefVectorWithLeader& wf_list, const RefVectorWithLeader& p_list, @@ -74,19 +63,6 @@ class TWFdispatcher std::vector& ratios, TWFGrads& grads) const; - void flex_calcRatioGrad(const RefVectorWithLeader& wf_list, - const RefVectorWithLeader& p_list, - int iat, - std::vector& ratios, - std::vector& grad_new) const; - - void flex_calcRatioGradWithSpin(const RefVectorWithLeader& wf_list, - const RefVectorWithLeader& p_list, - int iat, - std::vector& ratios, - std::vector& grad_new, - std::vector& spingrad_new) const; - void flex_accept_rejectMove(const RefVectorWithLeader& wf_list, const RefVectorWithLeader& p_list, int iat, diff --git a/src/QMCWaveFunctions/TrialWaveFunction.cpp b/src/QMCWaveFunctions/TrialWaveFunction.cpp index 4d8c4f456c..c3add7761b 100644 --- a/src/QMCWaveFunctions/TrialWaveFunction.cpp +++ b/src/QMCWaveFunctions/TrialWaveFunction.cpp @@ -150,12 +150,12 @@ void TrialWaveFunction::mw_evaluateLog(const RefVectorWithLeader +void TrialWaveFunction::mw_evalGrad(const RefVectorWithLeader& wf_list, + const RefVectorWithLeader& p_list, + int iat, + TWFGrads& grads) +{ + if constexpr (CT == CoordsType::POS_SPIN) + mw_evalGradWithSpin(wf_list, p_list, iat, grads.grads_positions, grads.grads_spins); + else + mw_evalGrad(wf_list, p_list, iat, grads.grads_positions); +} + void TrialWaveFunction::mw_evalGrad(const RefVectorWithLeader& wf_list, const RefVectorWithLeader& p_list, int iat, @@ -668,6 +680,19 @@ TrialWaveFunction::ValueType TrialWaveFunction::calcRatioGradWithSpin(ParticleSe return static_cast(r); } +template +void TrialWaveFunction::mw_calcRatioGrad(const RefVectorWithLeader& wf_list, + const RefVectorWithLeader& p_list, + int iat, + std::vector& ratios, + TWFGrads& grads) +{ + if constexpr (CT == CoordsType::POS_SPIN) + mw_calcRatioGradWithSpin(wf_list, p_list, iat, ratios, grads.grads_positions, grads.grads_spins); + else + mw_calcRatioGrad(wf_list, p_list, iat, ratios, grads.grads_positions); +} + void TrialWaveFunction::mw_calcRatioGrad(const RefVectorWithLeader& wf_list, const RefVectorWithLeader& p_list, int iat, @@ -1332,4 +1357,27 @@ void TrialWaveFunction::initializeTWFFastDerivWrapper(const ParticleSet& P, TWFF } } +//explicit instantiations +template void TrialWaveFunction::mw_evalGrad(const RefVectorWithLeader& wf_list, + const RefVectorWithLeader& p_list, + int iat, + TWFGrads& grads); +template void TrialWaveFunction::mw_evalGrad( + const RefVectorWithLeader& wf_list, + const RefVectorWithLeader& p_list, + int iat, + TWFGrads& grads); +template void TrialWaveFunction::mw_calcRatioGrad( + const RefVectorWithLeader& wf_list, + const RefVectorWithLeader& p_list, + int iat, + std::vector& ratios, + TWFGrads& grads); +template void TrialWaveFunction::mw_calcRatioGrad( + const RefVectorWithLeader& wf_list, + const RefVectorWithLeader& p_list, + int iat, + std::vector& ratios, + TWFGrads& grads); + } // namespace qmcplusplus diff --git a/src/QMCWaveFunctions/TrialWaveFunction.h b/src/QMCWaveFunctions/TrialWaveFunction.h index 2f54492fcc..2b5b03bfdc 100644 --- a/src/QMCWaveFunctions/TrialWaveFunction.h +++ b/src/QMCWaveFunctions/TrialWaveFunction.h @@ -29,6 +29,7 @@ #include "type_traits/template_types.hpp" #include "Containers/MinimalContainers/RecordArray.hpp" #include "QMCWaveFunctions/TWFFastDerivWrapper.h" +#include "TWFGrads.hpp" #ifdef QMC_CUDA #include "type_traits/CUDATypes.h" #endif @@ -348,6 +349,18 @@ class TrialWaveFunction */ ValueType calcRatioGradWithSpin(ParticleSet& P, int iat, GradType& grad_iat, ComplexType& spingrad_iat); + /** batched version of ratioGrad + * + * all vector sizes must match + * implements switch between normal and WithSpin version + */ + template + static void mw_calcRatioGrad(const RefVectorWithLeader& wf_list, + const RefVectorWithLeader& p_list, + int iat, + std::vector& ratios, + TWFGrads& grads); + /** batched version of ratioGrad * * all vector sizes must match @@ -396,6 +409,18 @@ class TrialWaveFunction */ GradType evalGradWithSpin(ParticleSet& P, int iat, ComplexType& spingrad); + /** batched version of evalGrad + * + * This is static because it should have no direct access + * to any TWF. + * implements switch between normal and WithSpin version + */ + template + static void mw_evalGrad(const RefVectorWithLeader& wf_list, + const RefVectorWithLeader& p_list, + int iat, + TWFGrads& grads); + /** batched version of evalGrad * * This is static because it should have no direct access