Skip to content

Commit

Permalink
Remove test differences and fix build
Browse files Browse the repository at this point in the history
Signed-off-by: Steven Hahn <[email protected]>
  • Loading branch information
quantumsteve authored and williamfgc committed Nov 8, 2023
1 parent cb3d07c commit 06e69d0
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/QMCWaveFunctions/LCAO/LCAOrbitalBuilderT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,7 @@ typename LCAOrbitalBuilderT<T>::BasisSet_t* LCAOrbitalBuilderT<T>::createBasisSe
return mBasisSet;
}
#ifndef QMC_COMPLEX
#ifndef MIXED_PRECISION
template<>
std::unique_ptr<SPOSetT<double>> LCAOrbitalBuilderT<double>::createWithCuspCorrection(
xmlNodePtr cur,
Expand All @@ -484,7 +485,6 @@ std::unique_ptr<SPOSetT<double>> LCAOrbitalBuilderT<double>::createWithCuspCorre
lcwc->setOrbitalSetSize(lcwc->lcao.getOrbitalSetSize());
sposet = std::move(lcwc);
}
#ifndef MIXED_PRECISION
// Create a temporary particle set to use for cusp initialization.
// The particle coordinates left at the end are unsuitable for further
// computations. The coordinates get set to nuclear positions, which
Expand Down
56 changes: 48 additions & 8 deletions src/QMCWaveFunctions/LCAO/LCAOrbitalSetT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,41 @@ void LCAOrbitalSetT<T>::mw_evaluateVGLImplGEMM(const RefVectorWithLeader<SPOSetT
}
}

template<typename T>
void LCAOrbitalSetT<T>::mw_evaluateValueVPsImplGEMM(const RefVectorWithLeader<SPOSetT<T>>& spo_list,
const RefVectorWithLeader<const VirtualParticleSetT<T>>& vp_list,
OffloadMWVArray& vp_phi_v) const
{
assert(this == &spo_list.getLeader());
auto& spo_leader = spo_list.template getCastedLeader<LCAOrbitalSetT<T>>();
//const size_t nw = spo_list.size();
auto& vp_basis_v_mw = spo_leader.mw_mem_handle_.getResource().vp_basis_v_mw;
//Splatter basis_v
const size_t nVPs = vp_phi_v.size(0);
vp_basis_v_mw.resize(nVPs, BasisSetSize);

auto basis_list = spo_leader.extractBasisRefList(spo_list);
myBasisSet->mw_evaluateValueVPs(basis_list, vp_list, vp_basis_v_mw);
vp_basis_v_mw.updateFrom(); // TODO: remove this when gemm is implemented

if (Identity)
{
std::copy_n(vp_basis_v_mw.data_at(0, 0), this->OrbitalSetSize * nVPs, vp_phi_v.data_at(0, 0));
}
else
{
const size_t requested_orb_size = vp_phi_v.size(1);
assert(requested_orb_size <= this->OrbitalSetSize);
ValueMatrix C_partial_view(C->data(), requested_orb_size, BasisSetSize);
BLAS::gemm('T', 'N',
requested_orb_size, // MOs
nVPs, // walkers * Virtual Particles
BasisSetSize, // AOs
1, C_partial_view.data(), BasisSetSize, vp_basis_v_mw.data(), BasisSetSize, 0, vp_phi_v.data(),
requested_orb_size);
}
}

template<class T>
void LCAOrbitalSetT<T>::mw_evaluateValue(const RefVectorWithLeader<SPOSetT<T>>& spo_list,
const RefVectorWithLeader<ParticleSetT<T>>& P_list,
Expand Down Expand Up @@ -579,15 +614,20 @@ void LCAOrbitalSetT<T>::mw_evaluateDetRatios(const RefVectorWithLeader<SPOSetT<T
const std::vector<const T*>& invRow_ptr_list,
std::vector<std::vector<T>>& ratios_list) const
{
const size_t nw = spo_list.size();
for (size_t iw = 0; iw < nw; iw++)
{
assert(this == &spo_list.getLeader());
auto& spo_leader = spo_list.template getCastedLeader<LCAOrbitalSetT<T>>();
auto& vp_phi_v = spo_leader.mw_mem_handle_.getResource().vp_phi_v;

const size_t nVPs = VirtualParticleSetT<T>::countVPs(vp_list);
const size_t requested_orb_size = psi_list[0].get().size();
vp_phi_v.resize(nVPs, requested_orb_size);

mw_evaluateValueVPsImplGEMM(spo_list, vp_list, vp_phi_v);

size_t index = 0;
for (size_t iw = 0; iw < vp_list.size(); iw++)
for (size_t iat = 0; iat < vp_list[iw].getTotalNum(); iat++)
{
spo_list[iw].evaluateValue(vp_list[iw], iat, psi_list[iw]);
ratios_list[iw][iat] = simd::dot(psi_list[iw].get().data(), invRow_ptr_list[iw], psi_list[iw].get().size());
}
}
ratios_list[iw][iat] = simd::dot(vp_phi_v.data_at(index++, 0), invRow_ptr_list[iw], requested_orb_size);
}

template<class T>
Expand Down
5 changes: 5 additions & 0 deletions src/QMCWaveFunctions/LCAO/LCAOrbitalSetT.h
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,11 @@ class LCAOrbitalSetT : public SPOSetT<T>
int iat,
OffloadMWVArray& phi_v) const;

/// packed walker GEMM implementation with multi virtual particle sets
void mw_evaluateValueVPsImplGEMM(const RefVectorWithLeader<SPOSetT<T>>& spo_list,
const RefVectorWithLeader<const VirtualParticleSetT<T>>& vp_list,
OffloadMWVArray& phi_v) const;

/// helper function for extracting a list of basis sets from a list of LCAOrbitalSet
RefVectorWithLeader<basis_type> extractBasisRefList(const RefVectorWithLeader<SPOSetT<T>>& spo_list) const;

Expand Down
8 changes: 4 additions & 4 deletions src/QMCWaveFunctions/tests/test_LCAO_diamondC_2x1x1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
#include "DistanceTable.h"
#include "QMCWaveFunctions/SPOSet.h"
#include "QMCWaveFunctions/LCAO/LCAOrbitalSet.h"
#include <stdio.h>

#include <cstdio>
#include <string>
#include <limits>

Expand Down Expand Up @@ -338,10 +339,9 @@ void test_LCAO_DiamondC_2x1x1_real()
ratios_list[iw].resize(nvp_list[iw]);

// just need dummy refvec with correct size
SPOSet::ValueVector tmp_psi_list(norb), tmp_psi_list_2(norb);
SPOSet::ValueVector tmp_psi_list(norb);
spo->mw_evaluateDetRatios(spo_list, RefVectorWithLeader<const VirtualParticleSet>(VP_, {VP_, VP_2}),
RefVector<SPOSet::ValueVector>{tmp_psi_list, tmp_psi_list_2}, invRow_ptr_list,
ratios_list);
RefVector<SPOSet::ValueVector>{tmp_psi_list}, invRow_ptr_list, ratios_list);

std::vector<SPOSet::ValueType> ratios_ref_0(nvp_);
std::vector<SPOSet::ValueType> ratios_ref_1(nvp_2);
Expand Down

0 comments on commit 06e69d0

Please sign in to comment.