Skip to content

Commit

Permalink
Add SplineC2RT
Browse files Browse the repository at this point in the history
Signed-off-by: Steven Hahn <[email protected]>
  • Loading branch information
quantumsteve committed Sep 8, 2023
1 parent abf2ce5 commit 7a2cd08
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 36 deletions.
56 changes: 28 additions & 28 deletions src/QMCWaveFunctions/BsplineFactory/SplineC2RT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@

namespace qmcplusplus
{
template<typename ST>
SplineC2R<ST>::SplineC2R(const SplineC2R& in) = default;
template<typename T>
SplineC2RT<T>::SplineC2RT(const SplineC2R& in) = default;

template<typename ST>
inline void SplineC2R<ST>::set_spline(SingleSplineType* spline_r,
template<typename T>
inline void SplineC2RT<T>::set_spline(SingleSplineType* spline_r,
SingleSplineType* spline_i,
int twist,
int ispline,
Expand All @@ -36,26 +36,26 @@ inline void SplineC2R<ST>::set_spline(SingleSplineType* spline_r,
SplineInst->copy_spline(spline_i, 2 * ispline + 1);
}

template<typename ST>
bool SplineC2R<ST>::read_splines(hdf_archive& h5f)
template<typename T>
bool SplineC2RT<T>::read_splines(hdf_archive& h5f)
{
std::ostringstream o;
o << "spline_" << MyIndex;
einspline_engine<SplineType> bigtable(SplineInst->getSplinePtr());
return h5f.readEntry(bigtable, o.str().c_str()); //"spline_0");
}

template<typename ST>
bool SplineC2R<ST>::write_splines(hdf_archive& h5f)
template<typename T>
bool SplineC2RT<T>::write_splines(hdf_archive& h5f)
{
std::ostringstream o;
o << "spline_" << MyIndex;
einspline_engine<SplineType> bigtable(SplineInst->getSplinePtr());
return h5f.writeEntry(bigtable, o.str().c_str()); //"spline_0");
}

template<typename ST>
inline void SplineC2R<ST>::assign_v(const PointType& r,
template<typename T>
inline void SplineC2RT<T>::assign_v(const PointType& r,
const vContainer_type& myV,
ValueVector& psi,
int first,
Expand Down Expand Up @@ -99,8 +99,8 @@ inline void SplineC2R<ST>::assign_v(const PointType& r,
}
}

template<typename ST>
void SplineC2R<ST>::evaluateValue(const ParticleSet& P, const int iat, ValueVector& psi)
template<typename T>
void SplineC2RT<T>::evaluateValue(const ParticleSet& P, const int iat, ValueVector& psi)
{
const PointType& r = P.activeR(iat);
PointType ru(PrimLattice.toUnit_floor(r));
Expand All @@ -115,8 +115,8 @@ void SplineC2R<ST>::evaluateValue(const ParticleSet& P, const int iat, ValueVect
}
}

template<typename ST>
void SplineC2R<ST>::evaluateDetRatios(const VirtualParticleSet& VP,
template<typename T>
void SplineC2RT<T>::evaluateDetRatios(const VirtualParticleSet& VP,
ValueVector& psi,
const ValueVector& psiinv,
std::vector<TT>& ratios)
Expand Down Expand Up @@ -163,8 +163,8 @@ void SplineC2R<ST>::evaluateDetRatios(const VirtualParticleSet& VP,

/** assign_vgl
*/
template<typename ST>
inline void SplineC2R<ST>::assign_vgl(const PointType& r,
template<typename T>
inline void SplineC2RT<T>::assign_vgl(const PointType& r,
ValueVector& psi,
GradVector& dpsi,
ValueVector& d2psi,
Expand Down Expand Up @@ -316,8 +316,8 @@ inline void SplineC2R<ST>::assign_vgl(const PointType& r,

/** assign_vgl_from_l can be used when myL is precomputed and myV,myG,myL in cartesian
*/
template<typename ST>
inline void SplineC2R<ST>::assign_vgl_from_l(const PointType& r, ValueVector& psi, GradVector& dpsi, ValueVector& d2psi)
template<typename T>
inline void SplineC2RT<T>::assign_vgl_from_l(const PointType& r, ValueVector& psi, GradVector& dpsi, ValueVector& d2psi)
{
constexpr ST two(2);
const ST x = r[0], y = r[1], z = r[2];
Expand Down Expand Up @@ -431,8 +431,8 @@ inline void SplineC2R<ST>::assign_vgl_from_l(const PointType& r, ValueVector& ps
}
}

template<typename ST>
void SplineC2R<ST>::evaluateVGL(const ParticleSet& P,
template<typename T>
void SplineC2RT<T>::evaluateVGL(const ParticleSet& P,
const int iat,
ValueVector& psi,
GradVector& dpsi,
Expand All @@ -451,8 +451,8 @@ void SplineC2R<ST>::evaluateVGL(const ParticleSet& P,
}
}

template<typename ST>
void SplineC2R<ST>::assign_vgh(const PointType& r,
template<typename T>
void SplineC2RT<T>::assign_vgh(const PointType& r,
ValueVector& psi,
GradVector& dpsi,
HessVector& grad_grad_psi,
Expand Down Expand Up @@ -675,8 +675,8 @@ void SplineC2R<ST>::assign_vgh(const PointType& r,
}
}

template<typename ST>
void SplineC2R<ST>::evaluateVGH(const ParticleSet& P,
template<typename T>
void SplineC2RT<T>::evaluateVGH(const ParticleSet& P,
const int iat,
ValueVector& psi,
GradVector& dpsi,
Expand All @@ -694,8 +694,8 @@ void SplineC2R<ST>::evaluateVGH(const ParticleSet& P,
}
}

template<typename ST>
void SplineC2R<ST>::assign_vghgh(const PointType& r,
template<typename T>
void SplineC2RT<T>::assign_vghgh(const PointType& r,
ValueVector& psi,
GradVector& dpsi,
HessVector& grad_grad_psi,
Expand Down Expand Up @@ -1182,7 +1182,7 @@ void SplineC2R<ST>::evaluateVGHGH(const ParticleSet& P,
}
}

template class SplineC2R<float>;
template class SplineC2R<double>;
template class SplineC2R<std::complex<float>>;
template class SplineC2R<std::complex<double>>;

} // namespace qmcplusplus
17 changes: 9 additions & 8 deletions src/QMCWaveFunctions/BsplineFactory/SplineC2RT.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,18 @@ namespace qmcplusplus
* The rest complex splines produce 1 real orbital.
* All the output orbitals are real (C2R). The maximal number of output orbitals is OrbitalSetSize.
*/
template<typename ST>
class SplineC2R : public BsplineSet
template<typename T>
class SplineC2RT : public BsplineSetT<T>
{
public:
using SplineType = typename bspline_traits<ST, 3>::SplineType;
using BCType = typename bspline_traits<ST, 3>::BCType;
using DataType = ST;
using DataType = T::value_type;
using PointType = TinyVector<ST, 3>;
using SingleSplineType = UBspline_3d_d;
// types for evaluation results
using TT = typename BsplineSet::ValueType;
using ST = T::value_type;
using TT = typename BsplineSetT<T>::ValueType;
using BsplineSet::GGGVector;
using BsplineSet::GradVector;
using BsplineSet::HessVector;
Expand Down Expand Up @@ -84,9 +85,9 @@ class SplineC2R : public BsplineSet
ghContainer_type mygH;

public:
SplineC2R(const std::string& my_name) : BsplineSet(my_name), nComplexBands(0) {}
SplineC2RT(const std::string& my_name) : BsplineSet(my_name), nComplexBands(0) {}

SplineC2R(const SplineC2R& in);
SplineC2RT(const SplineC2R& in);
virtual std::string getClassName() const override { return "SplineC2R"; }
virtual std::string getKeyword() const override { return "SplineC2R"; }
bool isComplex() const override { return true; };
Expand Down Expand Up @@ -210,8 +211,8 @@ class SplineC2R : public BsplineSet
friend struct BsplineReaderBase;
};

extern template class SplineC2R<float>;
extern template class SplineC2R<double>;
extern template class SplineC2R<std::complex<float>>;
extern template class SplineC2R<std::complex<double>>;

} // namespace qmcplusplus
#endif

0 comments on commit 7a2cd08

Please sign in to comment.