Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

batched evaluation of evaluate_notranspose for SpinorSet #3726

Merged
merged 11 commits into from
Jan 13, 2022
93 changes: 93 additions & 0 deletions src/QMCWaveFunctions/SpinorSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,99 @@ void SpinorSet::evaluate_notranspose(const ParticleSet& P,
}
}

void SpinorSet::mw_evaluate_notranspose(const RefVectorWithLeader<SPOSet>& spo_list,
const RefVectorWithLeader<ParticleSet>& P_list,
int first,
int last,
const RefVector<ValueMatrix_t>& logdet_list,
const RefVector<GradMatrix_t>& dlogdet_list,
const RefVector<ValueMatrix_t>& d2logdet_list) const
{
auto& spo_leader = spo_list.getCastedLeader<SpinorSet>();
auto& P_leader = P_list.getLeader();
assert(this == &spo_leader);

IndexType nw = spo_list.size();
IndexType nelec = P_leader.getTotalNum();

SPOSet& up_spo_leader = *(spo_leader.spo_up);
SPOSet& dn_spo_leader = *(spo_leader.spo_dn);
RefVectorWithLeader<SPOSet> up_spo_list(up_spo_leader);
RefVectorWithLeader<SPOSet> dn_spo_list(dn_spo_leader);
up_spo_list.reserve(nw);
dn_spo_list.reserve(nw);

std::vector<ValueMatrix_t> mw_up_logdet, mw_dn_logdet;
std::vector<GradMatrix_t> mw_up_dlogdet, mw_dn_dlogdet;
std::vector<ValueMatrix_t> mw_up_d2logdet, mw_dn_d2logdet;
mw_up_logdet.reserve(nw);
mw_dn_logdet.reserve(nw);
mw_up_dlogdet.reserve(nw);
mw_dn_dlogdet.reserve(nw);
mw_up_d2logdet.reserve(nw);
mw_dn_d2logdet.reserve(nw);

RefVector<ValueMatrix_t> up_logdet_list, dn_logdet_list;
RefVector<GradMatrix_t> up_dlogdet_list, dn_dlogdet_list;
RefVector<ValueMatrix_t> up_d2logdet_list, dn_d2logdet_list;
up_logdet_list.reserve(nw);
dn_logdet_list.reserve(nw);
up_dlogdet_list.reserve(nw);
dn_dlogdet_list.reserve(nw);
up_d2logdet_list.reserve(nw);
dn_d2logdet_list.reserve(nw);

ValueMatrix_t tmp_val_mat(nelec, OrbitalSetSize);
GradMatrix_t tmp_grad_mat(nelec, OrbitalSetSize);
for (int iw = 0; iw < nw; iw++)
{
SpinorSet& spinor = spo_list.getCastedElement<SpinorSet>(iw);
up_spo_list.emplace_back(*(spinor.spo_up));
dn_spo_list.emplace_back(*(spinor.spo_dn));

mw_up_logdet.emplace_back(tmp_val_mat);
up_logdet_list.emplace_back(mw_up_logdet.back());
mw_dn_logdet.emplace_back(tmp_val_mat);
dn_logdet_list.emplace_back(mw_dn_logdet.back());

mw_up_dlogdet.emplace_back(tmp_grad_mat);
up_dlogdet_list.emplace_back(mw_up_dlogdet.back());
mw_dn_dlogdet.emplace_back(tmp_grad_mat);
dn_dlogdet_list.emplace_back(mw_dn_dlogdet.back());

mw_up_d2logdet.emplace_back(tmp_val_mat);
up_d2logdet_list.emplace_back(mw_up_d2logdet.back());
mw_dn_d2logdet.emplace_back(tmp_val_mat);
dn_d2logdet_list.emplace_back(mw_dn_d2logdet.back());
}

up_spo_leader.mw_evaluate_notranspose(up_spo_list, P_list, first, last, up_logdet_list, up_dlogdet_list,
up_d2logdet_list);
dn_spo_leader.mw_evaluate_notranspose(dn_spo_list, P_list, first, last, dn_logdet_list, dn_dlogdet_list,
dn_d2logdet_list);

#pragma omp parallel for
for (int iw = 0; iw < nw; iw++)
for (int iat = 0; iat < nelec; iat++)
{
ParticleSet::Scalar_t s = P_list[iw].activeSpin(iat);
RealType coss = std::cos(s);
RealType sins = std::sin(s);
ValueType eis(coss, sins);
ValueType emis(coss, -sins);

for (int no = 0; no < OrbitalSetSize; no++)
{
logdet_list[iw].get()(iat, no) =
eis * up_logdet_list[iw].get()(iat, no) + emis * dn_logdet_list[iw].get()(iat, no);
dlogdet_list[iw].get()(iat, no) =
eis * up_dlogdet_list[iw].get()(iat, no) + emis * dn_dlogdet_list[iw].get()(iat, no);
d2logdet_list[iw].get()(iat, no) =
eis * up_d2logdet_list[iw].get()(iat, no) + emis * dn_d2logdet_list[iw].get()(iat, no);
}
}
}

void SpinorSet::evaluate_notranspose_spin(const ParticleSet& P,
int first,
int last,
Expand Down
8 changes: 8 additions & 0 deletions src/QMCWaveFunctions/SpinorSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,14 @@ class SpinorSet : public SPOSet
GradMatrix_t& dlogdet,
ValueMatrix_t& d2logdet) override;

void mw_evaluate_notranspose(const RefVectorWithLeader<SPOSet>& spo_list,
const RefVectorWithLeader<ParticleSet>& P_list,
int first,
int last,
const RefVector<ValueMatrix_t>& logdet_list,
const RefVector<GradMatrix_t>& dlogdet_list,
const RefVector<ValueMatrix_t>& d2logdet_list) const override;

void evaluate_notranspose_spin(const ParticleSet& P,
int first,
int last,
Expand Down
89 changes: 89 additions & 0 deletions src/QMCWaveFunctions/tests/test_einset_spinor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "Particle/ParticleSetPool.h"
#include "QMCWaveFunctions/WaveFunctionComponent.h"
#include "QMCWaveFunctions/SPOSetBuilderFactory.h"
#include "Utilities/ResourceCollection.h"

#include <stdio.h>
#include <string>
Expand Down Expand Up @@ -435,7 +436,95 @@ TEST_CASE("Einspline SpinorSet from HDF", "[wavefunction]")

elec_.rejectMove(iat);
}


// test batched interface
// first move elec_ back to original positions for reference
Rnew = elec_.R - dR;
elec_.R = Rnew;
elec_.update();

//now create second walker, with permuted particle positions
ParticleSet elec_2(elec_);
// permute electrons
elec_2.R[0] = elec_.R[1];
elec_2.R[1] = elec_.R[2];
elec_2.R[2] = elec_.R[0];
elec_2.spins[0] = elec_.spins[1];
elec_2.spins[1] = elec_.spins[2];
elec_2.spins[2] = elec_.spins[0];

ResourceCollection pset_res("test_pset_res");
elec_.createResource(pset_res);

RefVectorWithLeader<ParticleSet> p_list(elec_);
p_list.push_back(elec_);
p_list.push_back(elec_2);

ResourceCollectionTeamLock<ParticleSet> mw_pset_lock(pset_res, p_list);

//update all walkers
elec_.mw_update(p_list);

std::unique_ptr<SPOSet> spo_2(spo->makeClone());
RefVectorWithLeader<SPOSet> spo_list(*spo);
spo_list.push_back(*spo);
spo_list.push_back(*spo_2);

ye-luo marked this conversation as resolved.
Show resolved Hide resolved
SPOSet::ValueMatrix_t psiM_2(elec_.R.size(), spo->getOrbitalSetSize());
SPOSet::GradMatrix_t dpsiM_2(elec_.R.size(), spo->getOrbitalSetSize());
SPOSet::ValueMatrix_t d2psiM_2(elec_.R.size(), spo->getOrbitalSetSize());

RefVector<SPOSet::ValueMatrix_t> logdet_list;
RefVector<SPOSet::GradMatrix_t> dlogdet_list;
RefVector<SPOSet::ValueMatrix_t> d2logdet_list;

logdet_list.push_back(psiM);
logdet_list.push_back(psiM_2);
dlogdet_list.push_back(dpsiM);
dlogdet_list.push_back(dpsiM_2);
d2logdet_list.push_back(d2psiM);
d2logdet_list.push_back(d2psiM_2);

spo->mw_evaluate_notranspose(spo_list, p_list, 0, 3, logdet_list, dlogdet_list, d2logdet_list);
for (unsigned int iat = 0; iat < 3; iat++)
{
//walker 0
CHECK(logdet_list[0].get()[iat][0] == ComplexApprox(psiM_ref[iat][0]).epsilon(h));
CHECK(logdet_list[0].get()[iat][1] == ComplexApprox(psiM_ref[iat][1]).epsilon(h));
CHECK(logdet_list[0].get()[iat][2] == ComplexApprox(psiM_ref[iat][2]).epsilon(h));
CHECK(dlogdet_list[0].get()[iat][0][0] == ComplexApprox(dpsiM_ref[iat][0][0]).epsilon(h));
CHECK(dlogdet_list[0].get()[iat][0][1] == ComplexApprox(dpsiM_ref[iat][0][1]).epsilon(h));
CHECK(dlogdet_list[0].get()[iat][0][2] == ComplexApprox(dpsiM_ref[iat][0][2]).epsilon(h));
CHECK(dlogdet_list[0].get()[iat][1][0] == ComplexApprox(dpsiM_ref[iat][1][0]).epsilon(h));
CHECK(dlogdet_list[0].get()[iat][1][1] == ComplexApprox(dpsiM_ref[iat][1][1]).epsilon(h));
CHECK(dlogdet_list[0].get()[iat][1][2] == ComplexApprox(dpsiM_ref[iat][1][2]).epsilon(h));
CHECK(dlogdet_list[0].get()[iat][2][0] == ComplexApprox(dpsiM_ref[iat][2][0]).epsilon(h));
CHECK(dlogdet_list[0].get()[iat][2][1] == ComplexApprox(dpsiM_ref[iat][2][1]).epsilon(h));
CHECK(dlogdet_list[0].get()[iat][2][2] == ComplexApprox(dpsiM_ref[iat][2][2]).epsilon(h));
CHECK(d2logdet_list[0].get()[iat][0] == ComplexApprox(d2psiM_ref[iat][0]).epsilon(h2));
CHECK(d2logdet_list[0].get()[iat][1] == ComplexApprox(d2psiM_ref[iat][1]).epsilon(h2));
CHECK(d2logdet_list[0].get()[iat][2] == ComplexApprox(d2psiM_ref[iat][2]).epsilon(h2));

//walker 1, permuted from reference
CHECK(logdet_list[1].get()[iat][0] == ComplexApprox(psiM_ref[(iat + 1) % 3][0]).epsilon(h));
CHECK(logdet_list[1].get()[iat][1] == ComplexApprox(psiM_ref[(iat + 1) % 3][1]).epsilon(h));
CHECK(logdet_list[1].get()[iat][2] == ComplexApprox(psiM_ref[(iat + 1) % 3][2]).epsilon(h));
CHECK(dlogdet_list[1].get()[iat][0][0] == ComplexApprox(dpsiM_ref[(iat+1) % 3][0][0]).epsilon(h));
CHECK(dlogdet_list[1].get()[iat][0][1] == ComplexApprox(dpsiM_ref[(iat+1) % 3][0][1]).epsilon(h));
CHECK(dlogdet_list[1].get()[iat][0][2] == ComplexApprox(dpsiM_ref[(iat+1) % 3][0][2]).epsilon(h));
CHECK(dlogdet_list[1].get()[iat][1][0] == ComplexApprox(dpsiM_ref[(iat+1) % 3][1][0]).epsilon(h));
CHECK(dlogdet_list[1].get()[iat][1][1] == ComplexApprox(dpsiM_ref[(iat+1) % 3][1][1]).epsilon(h));
CHECK(dlogdet_list[1].get()[iat][1][2] == ComplexApprox(dpsiM_ref[(iat+1) % 3][1][2]).epsilon(h));
CHECK(dlogdet_list[1].get()[iat][2][0] == ComplexApprox(dpsiM_ref[(iat+1) % 3][2][0]).epsilon(h));
CHECK(dlogdet_list[1].get()[iat][2][1] == ComplexApprox(dpsiM_ref[(iat+1) % 3][2][1]).epsilon(h));
CHECK(dlogdet_list[1].get()[iat][2][2] == ComplexApprox(dpsiM_ref[(iat+1) % 3][2][2]).epsilon(h));
CHECK(d2logdet_list[1].get()[iat][0] == ComplexApprox(d2psiM_ref[(iat + 1) % 3][0]).epsilon(h2));
CHECK(d2logdet_list[1].get()[iat][1] == ComplexApprox(d2psiM_ref[(iat + 1) % 3][1]).epsilon(h2));
CHECK(d2logdet_list[1].get()[iat][2] == ComplexApprox(d2psiM_ref[(iat + 1) % 3][2]).epsilon(h2));
}
}

#endif //QMC_COMPLEX


Expand Down