diff --git a/src/QMCWaveFunctions/LCAO/LCAOrbitalBuilderT.cpp b/src/QMCWaveFunctions/LCAO/LCAOrbitalBuilderT.cpp index 84b2b3d9a8..bdfee067d1 100644 --- a/src/QMCWaveFunctions/LCAO/LCAOrbitalBuilderT.cpp +++ b/src/QMCWaveFunctions/LCAO/LCAOrbitalBuilderT.cpp @@ -468,6 +468,7 @@ typename LCAOrbitalBuilderT::BasisSet_t* LCAOrbitalBuilderT::createBasisSe return mBasisSet; } #ifndef QMC_COMPLEX +#ifndef MIXED_PRECISION template<> std::unique_ptr> LCAOrbitalBuilderT::createWithCuspCorrection( xmlNodePtr cur, @@ -484,7 +485,6 @@ std::unique_ptr> LCAOrbitalBuilderT::createWithCuspCorre lcwc->setOrbitalSetSize(lcwc->lcao.getOrbitalSetSize()); sposet = std::move(lcwc); } -#ifndef MIXED_PRECISION // Create a temporary particle set to use for cusp initialization. // The particle coordinates left at the end are unsuitable for further // computations. The coordinates get set to nuclear positions, which diff --git a/src/QMCWaveFunctions/LCAO/LCAOrbitalSetT.cpp b/src/QMCWaveFunctions/LCAO/LCAOrbitalSetT.cpp index 99c91ead30..7ee784d3a6 100644 --- a/src/QMCWaveFunctions/LCAO/LCAOrbitalSetT.cpp +++ b/src/QMCWaveFunctions/LCAO/LCAOrbitalSetT.cpp @@ -519,6 +519,41 @@ void LCAOrbitalSetT::mw_evaluateVGLImplGEMM(const RefVectorWithLeader +void LCAOrbitalSetT::mw_evaluateValueVPsImplGEMM(const RefVectorWithLeader>& spo_list, + const RefVectorWithLeader>& vp_list, + OffloadMWVArray& vp_phi_v) const +{ + assert(this == &spo_list.getLeader()); + auto& spo_leader = spo_list.template getCastedLeader>(); + //const size_t nw = spo_list.size(); + auto& vp_basis_v_mw = spo_leader.mw_mem_handle_.getResource().vp_basis_v_mw; + //Splatter basis_v + const size_t nVPs = vp_phi_v.size(0); + vp_basis_v_mw.resize(nVPs, BasisSetSize); + + auto basis_list = spo_leader.extractBasisRefList(spo_list); + myBasisSet->mw_evaluateValueVPs(basis_list, vp_list, vp_basis_v_mw); + vp_basis_v_mw.updateFrom(); // TODO: remove this when gemm is implemented + + if (Identity) + { + std::copy_n(vp_basis_v_mw.data_at(0, 0), this->OrbitalSetSize * nVPs, vp_phi_v.data_at(0, 0)); + } + else + { + const size_t requested_orb_size = vp_phi_v.size(1); + assert(requested_orb_size <= this->OrbitalSetSize); + ValueMatrix C_partial_view(C->data(), requested_orb_size, BasisSetSize); + BLAS::gemm('T', 'N', + requested_orb_size, // MOs + nVPs, // walkers * Virtual Particles + BasisSetSize, // AOs + 1, C_partial_view.data(), BasisSetSize, vp_basis_v_mw.data(), BasisSetSize, 0, vp_phi_v.data(), + requested_orb_size); + } +} + template void LCAOrbitalSetT::mw_evaluateValue(const RefVectorWithLeader>& spo_list, const RefVectorWithLeader>& P_list, @@ -579,15 +614,20 @@ void LCAOrbitalSetT::mw_evaluateDetRatios(const RefVectorWithLeader& invRow_ptr_list, std::vector>& ratios_list) const { - const size_t nw = spo_list.size(); - for (size_t iw = 0; iw < nw; iw++) - { + assert(this == &spo_list.getLeader()); + auto& spo_leader = spo_list.template getCastedLeader>(); + auto& vp_phi_v = spo_leader.mw_mem_handle_.getResource().vp_phi_v; + + const size_t nVPs = VirtualParticleSetT::countVPs(vp_list); + const size_t requested_orb_size = psi_list[0].get().size(); + vp_phi_v.resize(nVPs, requested_orb_size); + + mw_evaluateValueVPsImplGEMM(spo_list, vp_list, vp_phi_v); + + size_t index = 0; + for (size_t iw = 0; iw < vp_list.size(); iw++) for (size_t iat = 0; iat < vp_list[iw].getTotalNum(); iat++) - { - spo_list[iw].evaluateValue(vp_list[iw], iat, psi_list[iw]); - ratios_list[iw][iat] = simd::dot(psi_list[iw].get().data(), invRow_ptr_list[iw], psi_list[iw].get().size()); - } - } + ratios_list[iw][iat] = simd::dot(vp_phi_v.data_at(index++, 0), invRow_ptr_list[iw], requested_orb_size); } template diff --git a/src/QMCWaveFunctions/LCAO/LCAOrbitalSetT.h b/src/QMCWaveFunctions/LCAO/LCAOrbitalSetT.h index 24e979595c..a569e57e5e 100644 --- a/src/QMCWaveFunctions/LCAO/LCAOrbitalSetT.h +++ b/src/QMCWaveFunctions/LCAO/LCAOrbitalSetT.h @@ -358,6 +358,11 @@ class LCAOrbitalSetT : public SPOSetT int iat, OffloadMWVArray& phi_v) const; + /// packed walker GEMM implementation with multi virtual particle sets + void mw_evaluateValueVPsImplGEMM(const RefVectorWithLeader>& spo_list, + const RefVectorWithLeader>& vp_list, + OffloadMWVArray& phi_v) const; + /// helper function for extracting a list of basis sets from a list of LCAOrbitalSet RefVectorWithLeader extractBasisRefList(const RefVectorWithLeader>& spo_list) const; diff --git a/src/QMCWaveFunctions/tests/test_LCAO_diamondC_2x1x1.cpp b/src/QMCWaveFunctions/tests/test_LCAO_diamondC_2x1x1.cpp index 427a41b88f..f67b775b49 100644 --- a/src/QMCWaveFunctions/tests/test_LCAO_diamondC_2x1x1.cpp +++ b/src/QMCWaveFunctions/tests/test_LCAO_diamondC_2x1x1.cpp @@ -23,7 +23,8 @@ #include "DistanceTable.h" #include "QMCWaveFunctions/SPOSet.h" #include "QMCWaveFunctions/LCAO/LCAOrbitalSet.h" -#include + +#include #include #include @@ -338,10 +339,9 @@ void test_LCAO_DiamondC_2x1x1_real() ratios_list[iw].resize(nvp_list[iw]); // just need dummy refvec with correct size - SPOSet::ValueVector tmp_psi_list(norb), tmp_psi_list_2(norb); + SPOSet::ValueVector tmp_psi_list(norb); spo->mw_evaluateDetRatios(spo_list, RefVectorWithLeader(VP_, {VP_, VP_2}), - RefVector{tmp_psi_list, tmp_psi_list_2}, invRow_ptr_list, - ratios_list); + RefVector{tmp_psi_list}, invRow_ptr_list, ratios_list); std::vector ratios_ref_0(nvp_); std::vector ratios_ref_1(nvp_2);