diff --git a/src/QMCWaveFunctions/CMakeLists.txt b/src/QMCWaveFunctions/CMakeLists.txt index 52c6038f3ce..71f019357ee 100644 --- a/src/QMCWaveFunctions/CMakeLists.txt +++ b/src/QMCWaveFunctions/CMakeLists.txt @@ -47,11 +47,11 @@ set(WFBASE_SRCS SpinorSetT.cpp) if(NOT QMC_COMPLEX) - set(WFBASE_SRCS ${WFBASE_SRCS} RotatedSPOs.cpp) + set(WFBASE_SRCS ${WFBASE_SRCS} RotatedSPOs.cpp RotatedSPOsT.cpp) endif(NOT QMC_COMPLEX) if(QMC_COMPLEX) - set(WFBASE_SRCS ${WFBASE_SRCS} SpinorSet.cpp) + set(WFBASE_SRCS ${WFBASE_SRCS} SpinorSet.cpp SpinorSetT.cpp) endif(QMC_COMPLEX) ######################## # build jastrows @@ -149,7 +149,7 @@ if(OHMMS_DIM MATCHES 3) if(QMC_COMPLEX) set(FERMION_SRCS ${FERMION_SRCS} PlaneWave/PWOrbitalSet.cpp) else() - set(FERMION_SRCS ${FERMION_SRCS} PlaneWave/PWRealOrbitalSet.cpp) + set(FERMION_SRCS ${FERMION_SRCS} PlaneWave/PWRealOrbitalSet.cpp PlaneWave/PWRealOrbitalSetT.cpp) endif(QMC_COMPLEX) if(NOT QMC_COMPLEX) diff --git a/src/QMCWaveFunctions/tests/test_RotatedSPOs.cpp b/src/QMCWaveFunctions/tests/test_RotatedSPOs.cpp index af6f5b9cf0d..57b0cf8faf3 100644 --- a/src/QMCWaveFunctions/tests/test_RotatedSPOs.cpp +++ b/src/QMCWaveFunctions/tests/test_RotatedSPOs.cpp @@ -20,6 +20,8 @@ #include "QMCWaveFunctions/WaveFunctionComponent.h" #include "BsplineFactory/EinsplineSetBuilder.h" #include "QMCWaveFunctions/RotatedSPOs.h" +#include "QMCWaveFunctions/RotatedSPOsT.h" +#include "QMCWaveFunctions/SPOSetT.h" #include "checkMatrix.hpp" #include "FakeSPO.h" #include @@ -645,8 +647,22 @@ TEST_CASE("RotatedSPOs construct delta matrix", "[wavefunction]") namespace testing { opt_variables_type& getMyVars(SPOSet& rot) { return rot.myVars; } +opt_variables_type& getMyVars(SPOSetT& rot) { return rot.myVars; } +opt_variables_type& getMyVars(SPOSetT& rot) { return rot.myVars; } +opt_variables_type& getMyVars(SPOSetT>& rot) { return rot.myVars; } +opt_variables_type& getMyVars(SPOSetT>& rot) { return rot.myVars; } opt_variables_type& getMyVarsFull(RotatedSPOs& rot) { return rot.myVarsFull; } +opt_variables_type& getMyVarsFull(RotatedSPOsT& rot) { return rot.myVarsFull; } +opt_variables_type& getMyVarsFull(RotatedSPOsT& rot) { return rot.myVarsFull; } std::vector>& getHistoryParams(RotatedSPOs& rot) { return rot.history_params_; } +std::vector>& getHistoryParams(RotatedSPOsT& rot) +{ + return rot.history_params_; +} +std::vector>& getHistoryParams(RotatedSPOsT& rot) +{ + return rot.history_params_; +} } // namespace testing // Test using global rotation @@ -701,6 +717,58 @@ TEST_CASE("RotatedSPOs read and write parameters", "[wavefunction]") CHECK(full_var[5] == Approx(0.0)); } +// Test using global rotation +TEMPLATE_TEST_CASE("RotatedSPOs read and write parameters", "[wavefunction][template]", double, float) +{ + auto fake_spo = std::make_unique>(); + fake_spo->setOrbitalSetSize(4); + RotatedSPOsT rot("fake_rot", std::move(fake_spo)); + int nel = 2; + rot.buildOptVariables(nel); + + optimize::VariableSet vs; + rot.checkInVariablesExclusive(vs); + vs[0] = 0.1; + vs[1] = 0.15; + vs[2] = 0.2; + vs[3] = 0.25; + rot.resetParametersExclusive(vs); + + { + hdf_archive hout; + vs.writeToHDF("rot_vp.h5", hout); + + rot.writeVariationalParameters(hout); + } + + auto fake_spo2 = std::make_unique>(); + fake_spo2->setOrbitalSetSize(4); + + RotatedSPOsT rot2("fake_rot", std::move(fake_spo2)); + rot2.buildOptVariables(nel); + + optimize::VariableSet vs2; + rot2.checkInVariablesExclusive(vs2); + + hdf_archive hin; + vs2.readFromHDF("rot_vp.h5", hin); + rot2.readVariationalParameters(hin); + + opt_variables_type& var = testing::getMyVars(rot2); + CHECK(var[0] == Approx(vs[0])); + CHECK(var[1] == Approx(vs[1])); + CHECK(var[2] == Approx(vs[2])); + CHECK(var[3] == Approx(vs[3])); + + opt_variables_type& full_var = testing::getMyVarsFull(rot2); + CHECK(full_var[0] == Approx(vs[0])); + CHECK(full_var[1] == Approx(vs[1])); + CHECK(full_var[2] == Approx(vs[2])); + CHECK(full_var[3] == Approx(vs[3])); + CHECK(full_var[4] == Approx(0.0)); + CHECK(full_var[5] == Approx(0.0)); +} + // Test using history list. TEST_CASE("RotatedSPOs read and write parameters history", "[wavefunction]") {