Skip to content

Commit

Permalink
Change PWOrbitalBuilder to PWOrbitalSetBuilder
Browse files Browse the repository at this point in the history
  • Loading branch information
ye-luo committed Sep 28, 2023
1 parent 577fb82 commit 6ebe8f6
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 120 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ std::unique_ptr<SPOSet> EinsplineSetBuilder::createSPOSetFromXML(xmlNodePtr cur)
"no"); // use old spline library for high-order derivatives, e.g. needed for backflow optimization
std::string useGPU;
std::string GPUsharing = "no";
std::string spo_object_name;

ScopedTimer spo_timer_scope(createGlobalTimer("einspline::CreateSPOSetFromXML", timer_level_medium));

Expand Down Expand Up @@ -177,8 +176,6 @@ std::unique_ptr<SPOSet> EinsplineSetBuilder::createSPOSetFromXML(xmlNodePtr cur)
{
OhmmsAttributeSet oAttrib;
oAttrib.add(spinSet, "spindataset");
oAttrib.add(spo_object_name, "name");
oAttrib.add(spo_object_name, "id");
oAttrib.put(cur);
}

Expand Down
2 changes: 1 addition & 1 deletion src/QMCWaveFunctions/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ if(OHMMS_DIM MATCHES 3)
endif(HAVE_EINSPLINE)

# plane wave SPO
set(FERMION_SRCS ${FERMION_SRCS} PlaneWave/PWBasis.cpp PlaneWave/PWParameterSet.cpp PlaneWave/PWOrbitalBuilder.cpp)
set(FERMION_SRCS ${FERMION_SRCS} PlaneWave/PWBasis.cpp PlaneWave/PWParameterSet.cpp PlaneWave/PWOrbitalSetBuilder.cpp)
if(QMC_COMPLEX)
set(FERMION_SRCS ${FERMION_SRCS} PlaneWave/PWOrbitalSet.cpp)
else()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,38 +16,23 @@
/** @file
* @brief Definition of a builder class for PWOrbitalSet
*/
#include "PWOrbitalBuilder.h"
#include "PWOrbitalSetBuilder.h"
#include "QMCWaveFunctions/PlaneWave/PWParameterSet.h"
#include "QMCWaveFunctions/Fermion/DiracDeterminant.h"
#include "QMCWaveFunctions/Fermion/SlaterDet.h"
#include "OhmmsData/ParameterSet.h"
#include "OhmmsData/AttributeSet.h"
#include "Message/Communicate.h"

namespace qmcplusplus
{
PWOrbitalBuilder::PWOrbitalBuilder(Communicate* comm, ParticleSet& els, const PSetMap& psets)
: WaveFunctionComponentBuilder(comm, els), ptclPool(psets), myParam{std::make_unique<PWParameterSet>(comm)}, hfile{comm} {}

PWOrbitalBuilder::~PWOrbitalBuilder() = default;

//All data parsing is handled here, outside storage classes.
std::unique_ptr<WaveFunctionComponent> PWOrbitalBuilder::buildComponent(xmlNodePtr cur)
{
std::unique_ptr<WaveFunctionComponent> slater_det;
//save the parent
rootNode = cur;
PWOrbitalSetBuilder::PWOrbitalSetBuilder(const ParticleSet& p, Communicate* comm, xmlNodePtr cur)
: SPOSetBuilder("Planewave", comm), targetPtcl(p), rootNode(cur), myParam{std::make_unique<PWParameterSet>(comm)}, hfile{comm} {
//
//Get wavefunction data and parameters from XML and HDF5
//

//close it if open
hfile.close();
//catch parameters
myParam->put(cur);
//check the current href
bool success = getH5(cur, "href");
//no file, check the root
if (!success)
success = getH5(rootNode, "href");
//Move through the XML tree and read basis information
cur = cur->children;
while (cur != nullptr)
Expand All @@ -66,75 +51,26 @@ std::unique_ptr<WaveFunctionComponent> PWOrbitalBuilder::buildComponent(xmlNodeP
hfile.close();
success = getH5(cur, "hdata");
}
else if (cname == sd_tag)
{
if (!success)
success = getH5(cur, "href");
if (!success)
{
APP_ABORT(" Cannot create a SlaterDet due to missing h5 file\n");
OHMMS::Controller->abort();
}
createPWBasis(cur);
slater_det = putSlaterDet(cur);
}
cur = cur->next;
}
hfile.close();
return slater_det;

//create PW Basis
createPWBasis(cur);
}

std::unique_ptr<WaveFunctionComponent> PWOrbitalBuilder::putSlaterDet(xmlNodePtr cur)
PWOrbitalSetBuilder::~PWOrbitalSetBuilder() = default;

std::unique_ptr<SPOSet> PWOrbitalSetBuilder::createSPOSetFromXML(xmlNodePtr cur)
{
//catch parameters
myParam->put(cur);

std::vector<std::unique_ptr<DiracDeterminantBase>> dets;
int spin_group = 0;
cur = cur->children;
while (cur != NULL)
{
std::string cname((const char*)(cur->name));
//Which determinant?
if (cname == "determinant")
{
std::string id("updet");
std::string ref("0");
OhmmsAttributeSet aAttrib;
aAttrib.add(id, "id");
aAttrib.add(ref, "ref");
aAttrib.put(cur);
if (ref == "0")
ref = id;
const int firstIndex = targetPtcl.first(spin_group);
const int lastIndex = targetPtcl.last(spin_group);
std::map<std::string, SPOSetPtr>::iterator lit(spomap.find(ref));
//int spin_group=0;
if (lit == spomap.end())
{
app_log() << " Create a PWOrbitalSet" << std::endl;
std::unique_ptr<SPOSet> psi(createPW(cur, spin_group));
spomap[ref] = psi.get();
dets.push_back(std::make_unique<DiracDeterminant<>>(std::move(psi), firstIndex, lastIndex));
}
else
{
app_log() << " Reuse a PWOrbitalSet" << std::endl;
std::unique_ptr<SPOSet> psi((*lit).second->makeClone());
dets.push_back(std::make_unique<DiracDeterminant<>>(std::move(psi), firstIndex, lastIndex));
}
app_log() << " spin=" << spin_group << " id=" << id << " ref=" << ref << std::endl;
spin_group++;
}
cur = cur->next;
}

if (spin_group)
return std::make_unique<SlaterDet>(targetPtcl, std::move(dets));
;

myComm->barrier_and_abort(" Failed to create a SlaterDet at PWOrbitalBuilder::putSlaterDet ");
return nullptr;
std::string spo_object_name;
OhmmsAttributeSet aAttrib;
aAttrib.add(spin_group, "spindataset");
aAttrib.add(spo_object_name, "name");
aAttrib.add(spo_object_name, "id");
aAttrib.put(cur);
return createPW(cur, spo_object_name, spin_group);
}

/** The read routine - get data from XML and H5. Process it and build orbitals.
Expand All @@ -146,7 +82,7 @@ std::unique_ptr<WaveFunctionComponent> PWOrbitalBuilder::putSlaterDet(xmlNodePtr
* -- maximum_ecut
* - basis
*/
bool PWOrbitalBuilder::createPWBasis(xmlNodePtr cur)
bool PWOrbitalSetBuilder::createPWBasis(xmlNodePtr cur)
{
//recycle int and double reader
int idata;
Expand Down Expand Up @@ -193,7 +129,7 @@ bool PWOrbitalBuilder::createPWBasis(xmlNodePtr cur)
return true;
}

SPOSet* PWOrbitalBuilder::createPW(xmlNodePtr cur, int spinIndex)
std::unique_ptr<SPOSet> PWOrbitalSetBuilder::createPW(xmlNodePtr cur, const std::string& objname, int spinIndex)
{
int nb = targetPtcl.last(spinIndex) - targetPtcl.first(spinIndex);
std::vector<int> occBand(nb);
Expand All @@ -216,7 +152,6 @@ SPOSet* PWOrbitalBuilder::createPW(xmlNodePtr cur, int spinIndex)
std::string occMode("ground");
int bandoffset(1);
OhmmsAttributeSet aAttrib;
aAttrib.add(spinIndex, "spindataset");
aAttrib.add(occMode, "mode");
aAttrib.add(bandoffset, "offset"); /* reserved for index offset */
aAttrib.put(cur);
Expand Down Expand Up @@ -252,7 +187,7 @@ SPOSet* PWOrbitalBuilder::createPW(xmlNodePtr cur, int spinIndex)
hfile.push("electrons", false);
hfile.push("kpoint_0", false);
//create a single-particle orbital set
SPOSetType* psi = new SPOSetType(getXMLAttributeValue(cur, "name"));
auto psi = std::make_unique<SPOSetType>(objname);
if (transform2grid)
{
nb = myParam->numBands;
Expand Down Expand Up @@ -314,7 +249,7 @@ SPOSet* PWOrbitalBuilder::createPW(xmlNodePtr cur, int spinIndex)
}

#if defined(QMC_COMPLEX)
void PWOrbitalBuilder::transform2GridData(PWBasis::GIndex_t& nG, int spinIndex, PWOrbitalSet& pwFunc)
void PWOrbitalSetBuilder::transform2GridData(PWBasis::GIndex_t& nG, int spinIndex, PWOrbitalSet& pwFunc)
{
std::ostringstream splineTag;
splineTag << "eigenstates_" << nG[0] << "_" << nG[1] << "_" << nG[2];
Expand Down Expand Up @@ -380,6 +315,8 @@ void PWOrbitalBuilder::transform2GridData(PWBasis::GIndex_t& nG, int spinIndex,
for (int ib = 0; ib < nb; ib++)
inData.push_back(std::make_unique<StorageType>(nG[0], nG[1], nG[2]));
PosType tAngle = targetPtcl.getLattice().k_cart(TwistAngle);
ParticleSet ptemp(targetPtcl.getSimulationCell());
ptemp.create({1});
PWOrbitalSet::ValueVector phi(nb);
for (int ig = 0; ig < nG[0]; ig++)
{
Expand All @@ -389,9 +326,9 @@ void PWOrbitalBuilder::transform2GridData(PWBasis::GIndex_t& nG, int spinIndex,
RealType y = jg * dy;
for (int kg = 0; kg < nG[2]; kg++)
{
targetPtcl.R[0] = lattice.toCart(PosType(x, y, kg * dz));
pwFunc.evaluateValue(targetPtcl, 0, phi);
RealType x(dot(targetPtcl.R[0], tAngle));
ptemp.R[0] = lattice.toCart(PosType(x, y, kg * dz));
pwFunc.evaluateValue(ptemp, 0, phi);
RealType x(dot(ptemp.R[0], tAngle));
ValueType phase(std::cos(x), -std::sin(x));
for (int ib = 0; ib < nb; ib++)
(*inData[ib])(ig, jg, kg) = phase * phi[ib];
Expand Down Expand Up @@ -419,7 +356,7 @@ void PWOrbitalBuilder::transform2GridData(PWBasis::GIndex_t& nG, int spinIndex,
}
#endif

bool PWOrbitalBuilder::getH5(xmlNodePtr cur, const char* aname)
bool PWOrbitalSetBuilder::getH5(xmlNodePtr cur, const char* aname)
{
const std::string a(getXMLAttributeValue(cur, aname));
if (a.empty())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
// This file is distributed under the University of Illinois/NCSA Open Source License.
// See LICENSE file in top directory for details.
//
// Copyright (c) 2016 Jeongnim Kim and QMCPACK developers.
// Copyright (c) 2023 QMCPACK developers.
//
// File developed by: Jeongnim Kim, [email protected], University of Illinois at Urbana-Champaign
// File developed by: Ye Luo, [email protected], Argonne National Laboratory
// Jeongnim Kim, [email protected], University of Illinois at Urbana-Champaign
// Jeremy McMinnis, [email protected], University of Illinois at Urbana-Champaign
//
// File created by: Jeongnim Kim, [email protected], University of Illinois at Urbana-Champaign
Expand All @@ -15,23 +16,25 @@
* @brief Declaration of a builder class for PWOrbitalSet
*
*/
#ifndef QMCPLUSPLUS_PLANEWAVE_ORBITALBUILD_V0_H
#define QMCPLUSPLUS_PLANEWAVE_ORBITALBUILD_V0_H
#include "QMCWaveFunctions/WaveFunctionComponentBuilder.h"
#ifndef QMCPLUSPLUS_PWORBITAL_BUILDER_H
#define QMCPLUSPLUS_PWORBITAL_BUILDER_H

#include "SPOSetBuilder.h"
#include "hdf/hdf_archive.h"
#if defined(QMC_COMPLEX)
#include "QMCWaveFunctions/PlaneWave/PWOrbitalSet.h"
#else
#include "QMCWaveFunctions/PlaneWave/PWRealOrbitalSet.h"
#endif

namespace qmcplusplus
{
struct PWParameterSet;
class SlaterDet;

/** OrbitalBuilder for Slater determinants in PW basis
*/
class PWOrbitalBuilder : public WaveFunctionComponentBuilder
class PWOrbitalSetBuilder : public SPOSetBuilder
{
private:
#if defined(QMC_COMPLEX)
Expand All @@ -40,8 +43,8 @@ class PWOrbitalBuilder : public WaveFunctionComponentBuilder
using SPOSetType = PWRealOrbitalSet;
#endif

std::map<std::string, SPOSetPtr> spomap;
const PSetMap& ptclPool;
/// target particle set
const ParticleSet& targetPtcl;
///xml node for determinantset
xmlNodePtr rootNode{nullptr};
///input twist angle
Expand All @@ -54,17 +57,16 @@ class PWOrbitalBuilder : public WaveFunctionComponentBuilder
hdf_archive hfile;
public:
///constructor
PWOrbitalBuilder(Communicate* comm, ParticleSet& els, const PSetMap& psets);
~PWOrbitalBuilder() override;
PWOrbitalSetBuilder(const ParticleSet& p, Communicate* comm, xmlNodePtr cur);
~PWOrbitalSetBuilder() override;

///implement vritual function
std::unique_ptr<WaveFunctionComponent> buildComponent(xmlNodePtr cur) override;
/// create an sposet from xml and save the resulting SPOSet
std::unique_ptr<SPOSet> createSPOSetFromXML(xmlNodePtr cur) override;

private:
bool getH5(xmlNodePtr cur, const char* aname);
std::unique_ptr<WaveFunctionComponent> putSlaterDet(xmlNodePtr cur);
bool createPWBasis(xmlNodePtr cur);
SPOSet* createPW(xmlNodePtr cur, int spinIndex);
std::unique_ptr<SPOSet> createPW(xmlNodePtr cur, const std::string& objname, int spinIndex);
#if defined(QMC_COMPLEX)
void transform2GridData(PWBasis::GIndex_t& nG, int spinIndex, PWOrbitalSet& pwFunc);
#endif
Expand Down
18 changes: 12 additions & 6 deletions src/QMCWaveFunctions/SPOSetBuilderFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,24 @@


#include "SPOSetBuilderFactory.h"
#include "QMCWaveFunctions/SPOSetScanner.h"
#include "QMCWaveFunctions/HarmonicOscillator/SHOSetBuilder.h"
#include "SPOSetScanner.h"
#include "HarmonicOscillator/SHOSetBuilder.h"
#include "PlaneWave/PWOrbitalSetBuilder.h"
#include "ModernStringUtils.hpp"
#include "QMCWaveFunctions/ElectronGas/FreeOrbitalBuilder.h"
#include "ElectronGas/FreeOrbitalBuilder.h"
#if OHMMS_DIM == 3
#include "QMCWaveFunctions/LCAO/LCAOrbitalBuilder.h"
#include "LCAO/LCAOrbitalBuilder.h"

#if defined(QMC_COMPLEX)
#include "BsplineFactory/EinsplineSpinorSetBuilder.h"
#include "QMCWaveFunctions/LCAO/LCAOSpinorBuilder.h"
#include "LCAO/LCAOSpinorBuilder.h"
#endif

#if defined(HAVE_EINSPLINE)
#include "BsplineFactory/EinsplineSetBuilder.h"
#endif
#endif
#include "QMCWaveFunctions/CompositeSPOSet.h"
#include "CompositeSPOSet.h"
#include "Utilities/ProgressReportEngine.h"
#include "Utilities/IteratorUtility.h"
#include "OhmmsData/AttributeSet.h"
Expand Down Expand Up @@ -105,6 +106,11 @@ std::unique_ptr<SPOSetBuilder> SPOSetBuilderFactory::createSPOSetBuilder(xmlNode
app_log() << "Harmonic Oscillator SPO set" << std::endl;
bb = std::make_unique<SHOSetBuilder>(targetPtcl, myComm);
}
else if (type == "PWBasis" || type == "PW" || type == "pw")
{
app_log() << "Planewave basis SPO set" << std::endl;
bb = std::make_unique<PWOrbitalSetBuilder>(targetPtcl, myComm, rootNode);
}
#if OHMMS_DIM == 3
else if (type.find("spline") < type.size())
{
Expand Down
5 changes: 0 additions & 5 deletions src/QMCWaveFunctions/WaveFunctionFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include "QMCWaveFunctions/Fermion/SlaterDetBuilder.h"
#include "QMCWaveFunctions/LatticeGaussianProductBuilder.h"
#include "QMCWaveFunctions/ExampleHeBuilder.h"
#include "QMCWaveFunctions/PlaneWave/PWOrbitalBuilder.h"
#if OHMMS_DIM == 3 && !defined(QMC_COMPLEX)
#include "QMCWaveFunctions/AGPDeterminantBuilder.h"
#endif
Expand Down Expand Up @@ -187,10 +186,6 @@ bool WaveFunctionFactory::addFermionTerm(TrialWaveFunction& psi, SPOSetBuilderFa
msg << " please use \"free\" orbitals in sposet_builder" << std::endl;
throw std::runtime_error(msg.str());
}
else if (orbtype == "PWBasis" || orbtype == "PW" || orbtype == "pw")
{
detbuilder = std::make_unique<PWOrbitalBuilder>(myComm, targetPtcl, ptclPool);
}
else
detbuilder = std::make_unique<SlaterDetBuilder>(myComm, spo_factory, targetPtcl, psi, ptclPool);
psi.addComponent(detbuilder->buildComponent(cur));
Expand Down

0 comments on commit 6ebe8f6

Please sign in to comment.