Skip to content

Commit

Permalink
Merge pull request #4750 from ye-luo/pw-builder
Browse files Browse the repository at this point in the history
Rewrite planewave orbital set builder
  • Loading branch information
prckent authored Sep 29, 2023
2 parents 577fb82 + 439dbd0 commit 5d9bfeb
Show file tree
Hide file tree
Showing 8 changed files with 99 additions and 190 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ std::unique_ptr<SPOSet> EinsplineSetBuilder::createSPOSetFromXML(xmlNodePtr cur)
"no"); // use old spline library for high-order derivatives, e.g. needed for backflow optimization
std::string useGPU;
std::string GPUsharing = "no";
std::string spo_object_name;

ScopedTimer spo_timer_scope(createGlobalTimer("einspline::CreateSPOSetFromXML", timer_level_medium));

Expand Down Expand Up @@ -177,8 +176,6 @@ std::unique_ptr<SPOSet> EinsplineSetBuilder::createSPOSetFromXML(xmlNodePtr cur)
{
OhmmsAttributeSet oAttrib;
oAttrib.add(spinSet, "spindataset");
oAttrib.add(spo_object_name, "name");
oAttrib.add(spo_object_name, "id");
oAttrib.put(cur);
}

Expand Down
2 changes: 1 addition & 1 deletion src/QMCWaveFunctions/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ if(OHMMS_DIM MATCHES 3)
endif(HAVE_EINSPLINE)

# plane wave SPO
set(FERMION_SRCS ${FERMION_SRCS} PlaneWave/PWBasis.cpp PlaneWave/PWParameterSet.cpp PlaneWave/PWOrbitalBuilder.cpp)
set(FERMION_SRCS ${FERMION_SRCS} PlaneWave/PWBasis.cpp PlaneWave/PWParameterSet.cpp PlaneWave/PWOrbitalSetBuilder.cpp)
if(QMC_COMPLEX)
set(FERMION_SRCS ${FERMION_SRCS} PlaneWave/PWOrbitalSet.cpp)
else()
Expand Down
6 changes: 2 additions & 4 deletions src/QMCWaveFunctions/PlaneWave/PWBasis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,8 @@ int PWBasis::readbasis(hdf_archive& h5basisgroup,
h5basisgroup.read(gvecs, "/electrons/kpoint_0/gvectors");
NumPlaneWaves = std::max(gvecs.size(), kplusgvecs_cart.size());
if (NumPlaneWaves == 0)
{
app_error() << " PWBasis::readbasis Basis is missing. Abort " << std::endl;
abort(); //FIX_ABORT
}
throw std::runtime_error(" PWBasis::readbasis Basis is missing.");

if (kplusgvecs_cart.empty())
{
kplusgvecs_cart.resize(NumPlaneWaves);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
// This file is distributed under the University of Illinois/NCSA Open Source License.
// See LICENSE file in top directory for details.
//
// Copyright (c) 2016 Jeongnim Kim and QMCPACK developers.
// Copyright (c) 2023 QMCPACK developers.
//
// File developed by: Jeongnim Kim, [email protected], University of Illinois at Urbana-Champaign
// Jeremy McMinnis, [email protected], University of Illinois at Urbana-Champaign
// Mark A. Berrill, [email protected], Oak Ridge National Laboratory
// Mark Dewing, [email protected], University of Illinois at Urbana-Champaign
// Ye Luo, [email protected], Argonne National Laboratory
//
// File created by: Jeongnim Kim, [email protected], University of Illinois at Urbana-Champaign
//////////////////////////////////////////////////////////////////////////////////////
Expand All @@ -16,46 +17,33 @@
/** @file
* @brief Definition of a builder class for PWOrbitalSet
*/
#include "PWOrbitalBuilder.h"
#include "PWOrbitalSetBuilder.h"
#include "QMCWaveFunctions/PlaneWave/PWParameterSet.h"
#include "QMCWaveFunctions/Fermion/DiracDeterminant.h"
#include "QMCWaveFunctions/Fermion/SlaterDet.h"
#include "OhmmsData/ParameterSet.h"
#include "OhmmsData/AttributeSet.h"
#include "Message/Communicate.h"

namespace qmcplusplus
{
PWOrbitalBuilder::PWOrbitalBuilder(Communicate* comm, ParticleSet& els, const PSetMap& psets)
: WaveFunctionComponentBuilder(comm, els), ptclPool(psets), myParam{std::make_unique<PWParameterSet>(comm)}, hfile{comm} {}

PWOrbitalBuilder::~PWOrbitalBuilder() = default;

//All data parsing is handled here, outside storage classes.
std::unique_ptr<WaveFunctionComponent> PWOrbitalBuilder::buildComponent(xmlNodePtr cur)
PWOrbitalSetBuilder::PWOrbitalSetBuilder(const ParticleSet& p, Communicate* comm, xmlNodePtr cur)
: SPOSetBuilder("Planewave", comm),
targetPtcl(p),
rootNode(cur),
myParam{std::make_unique<PWParameterSet>(comm)},
hfile{comm}
{
std::unique_ptr<WaveFunctionComponent> slater_det;
//save the parent
rootNode = cur;
//
//Get wavefunction data and parameters from XML and HDF5
//

//close it if open
hfile.close();
//catch parameters
myParam->put(cur);
//check the current href
bool success = getH5(cur, "href");
//no file, check the root
if (!success)
success = getH5(rootNode, "href");
//Move through the XML tree and read basis information
cur = cur->children;
while (cur != nullptr)
{
std::string cname((const char*)(cur->name));
processChildren(cur, [&](const std::string& cname, const xmlNodePtr element) {
if (cname == "basisset")
{
const std::string a(getXMLAttributeValue(cur, "ecut"));
const std::string a(getXMLAttributeValue(element, "ecut"));
if (!a.empty())
myParam->Ecut = std::stod(a);
}
Expand All @@ -64,77 +52,28 @@ std::unique_ptr<WaveFunctionComponent> PWOrbitalBuilder::buildComponent(xmlNodeP
//close
if (success)
hfile.close();
success = getH5(cur, "hdata");
}
else if (cname == sd_tag)
{
if (!success)
success = getH5(cur, "href");
if (!success)
{
APP_ABORT(" Cannot create a SlaterDet due to missing h5 file\n");
OHMMS::Controller->abort();
}
createPWBasis(cur);
slater_det = putSlaterDet(cur);
success = getH5(element, "hdata");
}
cur = cur->next;
}
hfile.close();
return slater_det;
});

if (!success)
throw std::runtime_error("h5 cannot be open for creating PW basis!");
//create PW Basis
createPWBasis();
}

std::unique_ptr<WaveFunctionComponent> PWOrbitalBuilder::putSlaterDet(xmlNodePtr cur)
{
//catch parameters
myParam->put(cur);
PWOrbitalSetBuilder::~PWOrbitalSetBuilder() = default;

std::vector<std::unique_ptr<DiracDeterminantBase>> dets;
std::unique_ptr<SPOSet> PWOrbitalSetBuilder::createSPOSetFromXML(xmlNodePtr cur)
{
int spin_group = 0;
cur = cur->children;
while (cur != NULL)
{
std::string cname((const char*)(cur->name));
//Which determinant?
if (cname == "determinant")
{
std::string id("updet");
std::string ref("0");
OhmmsAttributeSet aAttrib;
aAttrib.add(id, "id");
aAttrib.add(ref, "ref");
aAttrib.put(cur);
if (ref == "0")
ref = id;
const int firstIndex = targetPtcl.first(spin_group);
const int lastIndex = targetPtcl.last(spin_group);
std::map<std::string, SPOSetPtr>::iterator lit(spomap.find(ref));
//int spin_group=0;
if (lit == spomap.end())
{
app_log() << " Create a PWOrbitalSet" << std::endl;
std::unique_ptr<SPOSet> psi(createPW(cur, spin_group));
spomap[ref] = psi.get();
dets.push_back(std::make_unique<DiracDeterminant<>>(std::move(psi), firstIndex, lastIndex));
}
else
{
app_log() << " Reuse a PWOrbitalSet" << std::endl;
std::unique_ptr<SPOSet> psi((*lit).second->makeClone());
dets.push_back(std::make_unique<DiracDeterminant<>>(std::move(psi), firstIndex, lastIndex));
}
app_log() << " spin=" << spin_group << " id=" << id << " ref=" << ref << std::endl;
spin_group++;
}
cur = cur->next;
}

if (spin_group)
return std::make_unique<SlaterDet>(targetPtcl, std::move(dets));
;

myComm->barrier_and_abort(" Failed to create a SlaterDet at PWOrbitalBuilder::putSlaterDet ");
return nullptr;
std::string spo_object_name;
OhmmsAttributeSet aAttrib;
aAttrib.add(spin_group, "spindataset");
aAttrib.add(spo_object_name, "name");
aAttrib.add(spo_object_name, "id");
aAttrib.put(cur);
return createPW(cur, spo_object_name, spin_group);
}

/** The read routine - get data from XML and H5. Process it and build orbitals.
Expand All @@ -146,10 +85,10 @@ std::unique_ptr<WaveFunctionComponent> PWOrbitalBuilder::putSlaterDet(xmlNodePtr
* -- maximum_ecut
* - basis
*/
bool PWOrbitalBuilder::createPWBasis(xmlNodePtr cur)
bool PWOrbitalSetBuilder::createPWBasis()
{
//recycle int and double reader
int idata;
int idata = 0;
//start of parameters
hfile.read(idata, "electrons/number_of_kpoints");
int nkpts = idata;
Expand All @@ -174,9 +113,8 @@ bool PWOrbitalBuilder::createPWBasis(xmlNodePtr cur)
hfile.read(TwistAngle_DP, "/electrons/kpoint_0/reduced_k");
TwistAngle = TwistAngle_DP;
if (!myBasisSet)
{
myBasisSet = std::make_unique<PWBasis>(TwistAngle);
}

//Read the planewave basisset.
//Note that the same data is opened here for each twist angle-avoids duplication in the
//h5 file (which may become very large).
Expand All @@ -193,7 +131,7 @@ bool PWOrbitalBuilder::createPWBasis(xmlNodePtr cur)
return true;
}

SPOSet* PWOrbitalBuilder::createPW(xmlNodePtr cur, int spinIndex)
std::unique_ptr<SPOSet> PWOrbitalSetBuilder::createPW(xmlNodePtr cur, const std::string& objname, int spinIndex)
{
int nb = targetPtcl.last(spinIndex) - targetPtcl.first(spinIndex);
std::vector<int> occBand(nb);
Expand All @@ -216,7 +154,6 @@ SPOSet* PWOrbitalBuilder::createPW(xmlNodePtr cur, int spinIndex)
std::string occMode("ground");
int bandoffset(1);
OhmmsAttributeSet aAttrib;
aAttrib.add(spinIndex, "spindataset");
aAttrib.add(occMode, "mode");
aAttrib.add(bandoffset, "offset"); /* reserved for index offset */
aAttrib.put(cur);
Expand Down Expand Up @@ -252,7 +189,7 @@ SPOSet* PWOrbitalBuilder::createPW(xmlNodePtr cur, int spinIndex)
hfile.push("electrons", false);
hfile.push("kpoint_0", false);
//create a single-particle orbital set
SPOSetType* psi = new SPOSetType(getXMLAttributeValue(cur, "name"));
auto psi = std::make_unique<SPOSetType>(objname);
if (transform2grid)
{
nb = myParam->numBands;
Expand Down Expand Up @@ -314,11 +251,11 @@ SPOSet* PWOrbitalBuilder::createPW(xmlNodePtr cur, int spinIndex)
}

#if defined(QMC_COMPLEX)
void PWOrbitalBuilder::transform2GridData(PWBasis::GIndex_t& nG, int spinIndex, PWOrbitalSet& pwFunc)
void PWOrbitalSetBuilder::transform2GridData(PWBasis::GIndex_t& nG, int spinIndex, PWOrbitalSet& pwFunc)
{
std::ostringstream splineTag;
splineTag << "eigenstates_" << nG[0] << "_" << nG[1] << "_" << nG[2];
herr_t status = H5Eset_auto2(H5E_DEFAULT, NULL, NULL);
herr_t status = H5Eset_auto2(H5E_DEFAULT, NULL, NULL);
std::string splineTagStr = splineTag.str();
app_log() << " splineTag " << splineTagStr << std::endl;
if (!hfile.is_group(splineTagStr))
Expand Down Expand Up @@ -352,7 +289,6 @@ void PWOrbitalBuilder::transform2GridData(PWBasis::GIndex_t& nG, int spinIndex,
{
bname = myParam->getSpinName(spinIndex);
hfile.push(bname, true);
}
}
for (int ig = 0; ig < nG[0]; ig++)
{
Expand Down Expand Up @@ -380,6 +316,8 @@ void PWOrbitalBuilder::transform2GridData(PWBasis::GIndex_t& nG, int spinIndex,
for (int ib = 0; ib < nb; ib++)
inData.push_back(std::make_unique<StorageType>(nG[0], nG[1], nG[2]));
PosType tAngle = targetPtcl.getLattice().k_cart(TwistAngle);
ParticleSet ptemp(targetPtcl.getSimulationCell());
ptemp.create({1});
PWOrbitalSet::ValueVector phi(nb);
for (int ig = 0; ig < nG[0]; ig++)
{
Expand All @@ -389,9 +327,9 @@ void PWOrbitalBuilder::transform2GridData(PWBasis::GIndex_t& nG, int spinIndex,
RealType y = jg * dy;
for (int kg = 0; kg < nG[2]; kg++)
{
targetPtcl.R[0] = lattice.toCart(PosType(x, y, kg * dz));
pwFunc.evaluateValue(targetPtcl, 0, phi);
RealType x(dot(targetPtcl.R[0], tAngle));
ptemp.R[0] = lattice.toCart(PosType(x, y, kg * dz));
pwFunc.evaluateValue(ptemp, 0, phi);
RealType x(dot(ptemp.R[0], tAngle));
ValueType phase(std::cos(x), -std::sin(x));
for (int ib = 0; ib < nb; ib++)
(*inData[ib])(ig, jg, kg) = phase * phi[ib];
Expand Down Expand Up @@ -419,7 +357,7 @@ void PWOrbitalBuilder::transform2GridData(PWBasis::GIndex_t& nG, int spinIndex,
}
#endif

bool PWOrbitalBuilder::getH5(xmlNodePtr cur, const char* aname)
bool PWOrbitalSetBuilder::getH5(xmlNodePtr cur, const char* aname)
{
const std::string a(getXMLAttributeValue(cur, aname));
if (a.empty())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
// This file is distributed under the University of Illinois/NCSA Open Source License.
// See LICENSE file in top directory for details.
//
// Copyright (c) 2016 Jeongnim Kim and QMCPACK developers.
// Copyright (c) 2023 QMCPACK developers.
//
// File developed by: Jeongnim Kim, [email protected], University of Illinois at Urbana-Champaign
// File developed by: Ye Luo, [email protected], Argonne National Laboratory
// Jeongnim Kim, [email protected], University of Illinois at Urbana-Champaign
// Jeremy McMinnis, [email protected], University of Illinois at Urbana-Champaign
//
// File created by: Jeongnim Kim, [email protected], University of Illinois at Urbana-Champaign
Expand All @@ -15,23 +16,25 @@
* @brief Declaration of a builder class for PWOrbitalSet
*
*/
#ifndef QMCPLUSPLUS_PLANEWAVE_ORBITALBUILD_V0_H
#define QMCPLUSPLUS_PLANEWAVE_ORBITALBUILD_V0_H
#include "QMCWaveFunctions/WaveFunctionComponentBuilder.h"
#ifndef QMCPLUSPLUS_PWORBITAL_BUILDER_H
#define QMCPLUSPLUS_PWORBITAL_BUILDER_H

#include "SPOSetBuilder.h"
#include "hdf/hdf_archive.h"
#if defined(QMC_COMPLEX)
#include "QMCWaveFunctions/PlaneWave/PWOrbitalSet.h"
#else
#include "QMCWaveFunctions/PlaneWave/PWRealOrbitalSet.h"
#endif

namespace qmcplusplus
{
struct PWParameterSet;
class SlaterDet;

/** OrbitalBuilder for Slater determinants in PW basis
*/
class PWOrbitalBuilder : public WaveFunctionComponentBuilder
class PWOrbitalSetBuilder : public SPOSetBuilder
{
private:
#if defined(QMC_COMPLEX)
Expand All @@ -40,8 +43,8 @@ class PWOrbitalBuilder : public WaveFunctionComponentBuilder
using SPOSetType = PWRealOrbitalSet;
#endif

std::map<std::string, SPOSetPtr> spomap;
const PSetMap& ptclPool;
/// target particle set
const ParticleSet& targetPtcl;
///xml node for determinantset
xmlNodePtr rootNode{nullptr};
///input twist angle
Expand All @@ -52,19 +55,19 @@ class PWOrbitalBuilder : public WaveFunctionComponentBuilder
std::unique_ptr<PWBasis> myBasisSet;
///hdf5 handler to clean up
hdf_archive hfile;

public:
///constructor
PWOrbitalBuilder(Communicate* comm, ParticleSet& els, const PSetMap& psets);
~PWOrbitalBuilder() override;
PWOrbitalSetBuilder(const ParticleSet& p, Communicate* comm, xmlNodePtr cur);
~PWOrbitalSetBuilder() override;

///implement vritual function
std::unique_ptr<WaveFunctionComponent> buildComponent(xmlNodePtr cur) override;
/// create an sposet from xml and save the resulting SPOSet
std::unique_ptr<SPOSet> createSPOSetFromXML(xmlNodePtr cur) override;

private:
bool getH5(xmlNodePtr cur, const char* aname);
std::unique_ptr<WaveFunctionComponent> putSlaterDet(xmlNodePtr cur);
bool createPWBasis(xmlNodePtr cur);
SPOSet* createPW(xmlNodePtr cur, int spinIndex);
bool createPWBasis();
std::unique_ptr<SPOSet> createPW(xmlNodePtr cur, const std::string& objname, int spinIndex);
#if defined(QMC_COMPLEX)
void transform2GridData(PWBasis::GIndex_t& nG, int spinIndex, PWOrbitalSet& pwFunc);
#endif
Expand Down
Loading

0 comments on commit 5d9bfeb

Please sign in to comment.