From f466ca3f172917ff622811bce46d20963fe79787 Mon Sep 17 00:00:00 2001 From: William F Godoy Date: Mon, 24 Jul 2023 15:09:01 -0400 Subject: [PATCH] Implement SPOSetT template class Asses the initial effort to refactor SPOSet into templates without consumers or tests. Concretize friend class declaration Define testing::getMyVars for SPOSetT Add FakeSPOT class Move SpinorSet to a templated class Refactor FreeOrbital class Base typed aliases on SPOSet on OrbitalSetTraits Add FullRealType in SPOSet and RotatedSPOs Add this in templated meta class Add explicit function instantions for FreeOrbital Add templated class SHOSetT Add PWRealOrbitalSetT template class Revert test_RotatedSPOs.cpp Signed-off-by: Steven Hahn --- src/QMCWaveFunctions/CMakeLists.txt | 6 +- .../tests/test_RotatedSPOs.cpp | 68 +++++++++++++++++++ 2 files changed, 71 insertions(+), 3 deletions(-) diff --git a/src/QMCWaveFunctions/CMakeLists.txt b/src/QMCWaveFunctions/CMakeLists.txt index 52c6038f3ce..71f019357ee 100644 --- a/src/QMCWaveFunctions/CMakeLists.txt +++ b/src/QMCWaveFunctions/CMakeLists.txt @@ -47,11 +47,11 @@ set(WFBASE_SRCS SpinorSetT.cpp) if(NOT QMC_COMPLEX) - set(WFBASE_SRCS ${WFBASE_SRCS} RotatedSPOs.cpp) + set(WFBASE_SRCS ${WFBASE_SRCS} RotatedSPOs.cpp RotatedSPOsT.cpp) endif(NOT QMC_COMPLEX) if(QMC_COMPLEX) - set(WFBASE_SRCS ${WFBASE_SRCS} SpinorSet.cpp) + set(WFBASE_SRCS ${WFBASE_SRCS} SpinorSet.cpp SpinorSetT.cpp) endif(QMC_COMPLEX) ######################## # build jastrows @@ -149,7 +149,7 @@ if(OHMMS_DIM MATCHES 3) if(QMC_COMPLEX) set(FERMION_SRCS ${FERMION_SRCS} PlaneWave/PWOrbitalSet.cpp) else() - set(FERMION_SRCS ${FERMION_SRCS} PlaneWave/PWRealOrbitalSet.cpp) + set(FERMION_SRCS ${FERMION_SRCS} PlaneWave/PWRealOrbitalSet.cpp PlaneWave/PWRealOrbitalSetT.cpp) endif(QMC_COMPLEX) if(NOT QMC_COMPLEX) diff --git a/src/QMCWaveFunctions/tests/test_RotatedSPOs.cpp b/src/QMCWaveFunctions/tests/test_RotatedSPOs.cpp index af6f5b9cf0d..57b0cf8faf3 100644 --- a/src/QMCWaveFunctions/tests/test_RotatedSPOs.cpp +++ b/src/QMCWaveFunctions/tests/test_RotatedSPOs.cpp @@ -20,6 +20,8 @@ #include "QMCWaveFunctions/WaveFunctionComponent.h" #include "BsplineFactory/EinsplineSetBuilder.h" #include "QMCWaveFunctions/RotatedSPOs.h" +#include "QMCWaveFunctions/RotatedSPOsT.h" +#include "QMCWaveFunctions/SPOSetT.h" #include "checkMatrix.hpp" #include "FakeSPO.h" #include @@ -645,8 +647,22 @@ TEST_CASE("RotatedSPOs construct delta matrix", "[wavefunction]") namespace testing { opt_variables_type& getMyVars(SPOSet& rot) { return rot.myVars; } +opt_variables_type& getMyVars(SPOSetT& rot) { return rot.myVars; } +opt_variables_type& getMyVars(SPOSetT& rot) { return rot.myVars; } +opt_variables_type& getMyVars(SPOSetT>& rot) { return rot.myVars; } +opt_variables_type& getMyVars(SPOSetT>& rot) { return rot.myVars; } opt_variables_type& getMyVarsFull(RotatedSPOs& rot) { return rot.myVarsFull; } +opt_variables_type& getMyVarsFull(RotatedSPOsT& rot) { return rot.myVarsFull; } +opt_variables_type& getMyVarsFull(RotatedSPOsT& rot) { return rot.myVarsFull; } std::vector>& getHistoryParams(RotatedSPOs& rot) { return rot.history_params_; } +std::vector>& getHistoryParams(RotatedSPOsT& rot) +{ + return rot.history_params_; +} +std::vector>& getHistoryParams(RotatedSPOsT& rot) +{ + return rot.history_params_; +} } // namespace testing // Test using global rotation @@ -701,6 +717,58 @@ TEST_CASE("RotatedSPOs read and write parameters", "[wavefunction]") CHECK(full_var[5] == Approx(0.0)); } +// Test using global rotation +TEMPLATE_TEST_CASE("RotatedSPOs read and write parameters", "[wavefunction][template]", double, float) +{ + auto fake_spo = std::make_unique>(); + fake_spo->setOrbitalSetSize(4); + RotatedSPOsT rot("fake_rot", std::move(fake_spo)); + int nel = 2; + rot.buildOptVariables(nel); + + optimize::VariableSet vs; + rot.checkInVariablesExclusive(vs); + vs[0] = 0.1; + vs[1] = 0.15; + vs[2] = 0.2; + vs[3] = 0.25; + rot.resetParametersExclusive(vs); + + { + hdf_archive hout; + vs.writeToHDF("rot_vp.h5", hout); + + rot.writeVariationalParameters(hout); + } + + auto fake_spo2 = std::make_unique>(); + fake_spo2->setOrbitalSetSize(4); + + RotatedSPOsT rot2("fake_rot", std::move(fake_spo2)); + rot2.buildOptVariables(nel); + + optimize::VariableSet vs2; + rot2.checkInVariablesExclusive(vs2); + + hdf_archive hin; + vs2.readFromHDF("rot_vp.h5", hin); + rot2.readVariationalParameters(hin); + + opt_variables_type& var = testing::getMyVars(rot2); + CHECK(var[0] == Approx(vs[0])); + CHECK(var[1] == Approx(vs[1])); + CHECK(var[2] == Approx(vs[2])); + CHECK(var[3] == Approx(vs[3])); + + opt_variables_type& full_var = testing::getMyVarsFull(rot2); + CHECK(full_var[0] == Approx(vs[0])); + CHECK(full_var[1] == Approx(vs[1])); + CHECK(full_var[2] == Approx(vs[2])); + CHECK(full_var[3] == Approx(vs[3])); + CHECK(full_var[4] == Approx(0.0)); + CHECK(full_var[5] == Approx(0.0)); +} + // Test using history list. TEST_CASE("RotatedSPOs read and write parameters history", "[wavefunction]") {