Skip to content

Commit

Permalink
Specialize functions in RotatedSPOsT
Browse files Browse the repository at this point in the history
Fix function signature
  • Loading branch information
williamfgc committed Aug 28, 2023
1 parent acb8862 commit 9c61923
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 23 deletions.
6 changes: 3 additions & 3 deletions src/QMCWaveFunctions/RotatedSPOsT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -975,9 +975,9 @@ void RotatedSPOsT<T>::evaluateDerivatives(ParticleSet& P,
template<typename T>
void RotatedSPOsT<T>::evaluateDerivativesWF(ParticleSet& P,
const opt_variables_type& optvars,
Vector<T>& dlogpsi,
const FullRealType& psiCurrent,
const std::vector<T>& Coeff,
Vector<ValueType>& dlogpsi,
const ValueType& psiCurrent,
const std::vector<ValueType>& Coeff,
const std::vector<size_t>& C2node_up,
const std::vector<size_t>& C2node_dn,
const ValueVector& detValues_up,
Expand Down
7 changes: 4 additions & 3 deletions src/QMCWaveFunctions/RotatedSPOsT.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class RotatedSPOsT : public SPOSetT<T>, public OptimizableObject
public:
using IndexType = typename SPOSetT<T>::IndexType;
using RealType = typename SPOSetT<T>::RealType;
using ValueType = typename SPOSetT<T>::ValueType;
using FullRealType = typename SPOSetT<T>::FullRealType;
using ValueVector = typename SPOSetT<T>::ValueVector;
using ValueMatrix = typename SPOSetT<T>::ValueMatrix;
Expand Down Expand Up @@ -200,9 +201,9 @@ class RotatedSPOsT : public SPOSetT<T>, public OptimizableObject

void evaluateDerivativesWF(ParticleSet& P,
const opt_variables_type& optvars,
Vector<T>& dlogpsi,
const FullRealType& psiCurrent,
const std::vector<T>& Coeff,
Vector<ValueType>& dlogpsi,
const ValueType& psiCurrent,
const std::vector<ValueType>& Coeff,
const std::vector<size_t>& C2node_up,
const std::vector<size_t>& C2node_dn,
const ValueVector& detValues_up,
Expand Down
86 changes: 73 additions & 13 deletions src/QMCWaveFunctions/SPOSetBuilderT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
#include "SPOSetBuilderT.h"
#include "OhmmsData/AttributeSet.h"
#include <Message/UniformCommunicateError.h>

#ifndef QMC_COMPLEX
#include "QMCWaveFunctions/RotatedSPOsT.h"
#endif
#include "QMCWaveFunctions/RotatedSPOsT.h" // only for real wavefunctions

namespace qmcplusplus
{
Expand Down Expand Up @@ -133,8 +130,8 @@ std::unique_ptr<SPOSetT<T>> SPOSetBuilderT<T>::createSPOSet(xmlNodePtr cur)
return sposet;
}

template<typename T>
std::unique_ptr<SPOSetT<T>> SPOSetBuilderT<T>::createRotatedSPOSet(xmlNodePtr cur)
template<>
std::unique_ptr<SPOSetT<float>> SPOSetBuilderT<float>::createRotatedSPOSet(xmlNodePtr cur)
{
std::string spo_object_name;
std::string method;
Expand All @@ -143,12 +140,49 @@ std::unique_ptr<SPOSetT<T>> SPOSetBuilderT<T>::createRotatedSPOSet(xmlNodePtr cu
attrib.add(method, "method", {"global", "history"});
attrib.put(cur);

std::unique_ptr<SPOSetT<float>> sposet;
processChildren(cur, [&](const std::string& cname, const xmlNodePtr element) {
if (cname == "sposet")
{
sposet = createSPOSet(element);
}
});

if (!sposet)
myComm->barrier_and_abort("Rotated SPO needs an SPOset");

if (!sposet->isRotationSupported())
myComm->barrier_and_abort("Orbital rotation not supported with '" + sposet->getName() + "' of type '" +
sposet->getClassName() + "'.");

sposet->storeParamsBeforeRotation();
auto rot_spo = std::make_unique<RotatedSPOsT<float>>(spo_object_name, std::move(sposet));

if (method == "history")
rot_spo->set_use_global_rotation(false);

#ifdef QMC_COMPLEX
myComm->barrier_and_abort("Orbital optimization via rotation doesn't support complex wavefunctions yet.");
return nullptr;
#else
std::unique_ptr<SPOSetT<T>> sposet;
processChildren(cur, [&](const std::string& cname, const xmlNodePtr element) {
if (cname == "opt_vars")
{
std::vector<RealType> params;
putContent(params, element);
rot_spo->setRotationParameters(params);
}
});
return rot_spo;
}

template<>
std::unique_ptr<SPOSetT<double>> SPOSetBuilderT<double>::createRotatedSPOSet(xmlNodePtr cur)
{
std::string spo_object_name;
std::string method;
OhmmsAttributeSet attrib;
attrib.add(spo_object_name, "name");
attrib.add(method, "method", {"global", "history"});
attrib.put(cur);

std::unique_ptr<SPOSetT<double>> sposet;
processChildren(cur, [&](const std::string& cname, const xmlNodePtr element) {
if (cname == "sposet")
{
Expand All @@ -164,7 +198,7 @@ std::unique_ptr<SPOSetT<T>> SPOSetBuilderT<T>::createRotatedSPOSet(xmlNodePtr cu
sposet->getClassName() + "'.");

sposet->storeParamsBeforeRotation();
auto rot_spo = std::make_unique<RotatedSPOsT<T>>(spo_object_name, std::move(sposet));
auto rot_spo = std::make_unique<RotatedSPOsT<double>>(spo_object_name, std::move(sposet));

if (method == "history")
rot_spo->set_use_global_rotation(false);
Expand All @@ -178,8 +212,34 @@ std::unique_ptr<SPOSetT<T>> SPOSetBuilderT<T>::createRotatedSPOSet(xmlNodePtr cu
}
});
return rot_spo;
#endif
}

template<>
std::unique_ptr<SPOSetT<std::complex<float>>> SPOSetBuilderT<std::complex<float>>::createRotatedSPOSet(xmlNodePtr cur)
{
std::string spo_object_name;
std::string method;
OhmmsAttributeSet attrib;
attrib.add(spo_object_name, "name");
attrib.add(method, "method", {"global", "history"});
attrib.put(cur);
myComm->barrier_and_abort("Orbital optimization via rotation doesn't support complex wavefunctions yet.");
return nullptr;
}

template<>
std::unique_ptr<SPOSetT<std::complex<double>>> SPOSetBuilderT<std::complex<double>>::createRotatedSPOSet(xmlNodePtr cur)
{
std::string spo_object_name;
std::string method;
OhmmsAttributeSet attrib;
attrib.add(spo_object_name, "name");
attrib.add(method, "method", {"global", "history"});
attrib.put(cur);
myComm->barrier_and_abort("Orbital optimization via rotation doesn't support complex wavefunctions yet.");
return nullptr;
}

template class SPOSetBuilderT<double>;
template class SPOSetBuilderT<float>;
template class SPOSetBuilderT<std::complex<double>>;
Expand Down
4 changes: 2 additions & 2 deletions src/QMCWaveFunctions/SPOSetT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,8 @@ void SPOSetT<T>::evaluateDerivatives(ParticleSet& P,
template<class T>
void SPOSetT<T>::evaluateDerivativesWF(ParticleSet& P,
const opt_variables_type& optvars,
Vector<T>& dlogpsi,
const typename QTFull::ValueType& psiCurrent,
Vector<ValueType>& dlogpsi,
const ValueType& psiCurrent,
const std::vector<T>& Coeff,
const std::vector<size_t>& C2node_up,
const std::vector<size_t>& C2node_dn,
Expand Down
4 changes: 2 additions & 2 deletions src/QMCWaveFunctions/SPOSetT.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,8 @@ class SPOSetT : public QMCTraits
*/
virtual void evaluateDerivativesWF(ParticleSet& P,
const opt_variables_type& optvars,
Vector<T>& dlogpsi,
const typename QTFull::ValueType& psiCurrent,
Vector<ValueType>& dlogpsi,
const ValueType& psiCurrent,
const std::vector<T>& Coeff,
const std::vector<size_t>& C2node_up,
const std::vector<size_t>& C2node_dn,
Expand Down

0 comments on commit 9c61923

Please sign in to comment.