Skip to content

Commit

Permalink
Implement SPOSetT template class
Browse files Browse the repository at this point in the history
Asses the initial effort to refactor SPOSet into templates
without consumers or tests.
Concretize friend class declaration
Define testing::getMyVars for SPOSetT
Add FakeSPOT class
Move SpinorSet to a templated class
Refactor FreeOrbital class
Base typed aliases on SPOSet<T> on OrbitalSetTraits<T>
Add FullRealType in SPOSet and RotatedSPOs
Add this in templated meta class
Add explicit function instantions for FreeOrbital
Add templated class SHOSetT
Add PWRealOrbitalSetT template class
Revert test_RotatedSPOs.cpp

Signed-off-by: Steven Hahn <[email protected]>
  • Loading branch information
williamfgc committed Oct 25, 2023
1 parent d5614b3 commit f466ca3
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/QMCWaveFunctions/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ set(WFBASE_SRCS
SpinorSetT.cpp)

if(NOT QMC_COMPLEX)
set(WFBASE_SRCS ${WFBASE_SRCS} RotatedSPOs.cpp)
set(WFBASE_SRCS ${WFBASE_SRCS} RotatedSPOs.cpp RotatedSPOsT.cpp)
endif(NOT QMC_COMPLEX)

if(QMC_COMPLEX)
set(WFBASE_SRCS ${WFBASE_SRCS} SpinorSet.cpp)
set(WFBASE_SRCS ${WFBASE_SRCS} SpinorSet.cpp SpinorSetT.cpp)
endif(QMC_COMPLEX)
########################
# build jastrows
Expand Down Expand Up @@ -149,7 +149,7 @@ if(OHMMS_DIM MATCHES 3)
if(QMC_COMPLEX)
set(FERMION_SRCS ${FERMION_SRCS} PlaneWave/PWOrbitalSet.cpp)
else()
set(FERMION_SRCS ${FERMION_SRCS} PlaneWave/PWRealOrbitalSet.cpp)
set(FERMION_SRCS ${FERMION_SRCS} PlaneWave/PWRealOrbitalSet.cpp PlaneWave/PWRealOrbitalSetT.cpp)
endif(QMC_COMPLEX)

if(NOT QMC_COMPLEX)
Expand Down
68 changes: 68 additions & 0 deletions src/QMCWaveFunctions/tests/test_RotatedSPOs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include "QMCWaveFunctions/WaveFunctionComponent.h"
#include "BsplineFactory/EinsplineSetBuilder.h"
#include "QMCWaveFunctions/RotatedSPOs.h"
#include "QMCWaveFunctions/RotatedSPOsT.h"
#include "QMCWaveFunctions/SPOSetT.h"
#include "checkMatrix.hpp"
#include "FakeSPO.h"
#include <ResourceCollection.h>
Expand Down Expand Up @@ -645,8 +647,22 @@ TEST_CASE("RotatedSPOs construct delta matrix", "[wavefunction]")
namespace testing
{
opt_variables_type& getMyVars(SPOSet& rot) { return rot.myVars; }
opt_variables_type& getMyVars(SPOSetT<float>& rot) { return rot.myVars; }
opt_variables_type& getMyVars(SPOSetT<double>& rot) { return rot.myVars; }
opt_variables_type& getMyVars(SPOSetT<std::complex<float>>& rot) { return rot.myVars; }
opt_variables_type& getMyVars(SPOSetT<std::complex<double>>& rot) { return rot.myVars; }
opt_variables_type& getMyVarsFull(RotatedSPOs& rot) { return rot.myVarsFull; }
opt_variables_type& getMyVarsFull(RotatedSPOsT<double>& rot) { return rot.myVarsFull; }
opt_variables_type& getMyVarsFull(RotatedSPOsT<float>& rot) { return rot.myVarsFull; }
std::vector<std::vector<QMCTraits::RealType>>& getHistoryParams(RotatedSPOs& rot) { return rot.history_params_; }
std::vector<std::vector<QMCTraits::RealType>>& getHistoryParams(RotatedSPOsT<double>& rot)
{
return rot.history_params_;
}
std::vector<std::vector<QMCTraits::RealType>>& getHistoryParams(RotatedSPOsT<float>& rot)
{
return rot.history_params_;
}
} // namespace testing

// Test using global rotation
Expand Down Expand Up @@ -701,6 +717,58 @@ TEST_CASE("RotatedSPOs read and write parameters", "[wavefunction]")
CHECK(full_var[5] == Approx(0.0));
}

// Test using global rotation
TEMPLATE_TEST_CASE("RotatedSPOs read and write parameters", "[wavefunction][template]", double, float)
{
auto fake_spo = std::make_unique<FakeSPOT<TestType>>();
fake_spo->setOrbitalSetSize(4);
RotatedSPOsT<TestType> rot("fake_rot", std::move(fake_spo));
int nel = 2;
rot.buildOptVariables(nel);

optimize::VariableSet vs;
rot.checkInVariablesExclusive(vs);
vs[0] = 0.1;
vs[1] = 0.15;
vs[2] = 0.2;
vs[3] = 0.25;
rot.resetParametersExclusive(vs);

{
hdf_archive hout;
vs.writeToHDF("rot_vp.h5", hout);

rot.writeVariationalParameters(hout);
}

auto fake_spo2 = std::make_unique<FakeSPOT<TestType>>();
fake_spo2->setOrbitalSetSize(4);

RotatedSPOsT<TestType> rot2("fake_rot", std::move(fake_spo2));
rot2.buildOptVariables(nel);

optimize::VariableSet vs2;
rot2.checkInVariablesExclusive(vs2);

hdf_archive hin;
vs2.readFromHDF("rot_vp.h5", hin);
rot2.readVariationalParameters(hin);

opt_variables_type& var = testing::getMyVars(rot2);
CHECK(var[0] == Approx(vs[0]));
CHECK(var[1] == Approx(vs[1]));
CHECK(var[2] == Approx(vs[2]));
CHECK(var[3] == Approx(vs[3]));

opt_variables_type& full_var = testing::getMyVarsFull(rot2);
CHECK(full_var[0] == Approx(vs[0]));
CHECK(full_var[1] == Approx(vs[1]));
CHECK(full_var[2] == Approx(vs[2]));
CHECK(full_var[3] == Approx(vs[3]));
CHECK(full_var[4] == Approx(0.0));
CHECK(full_var[5] == Approx(0.0));
}

// Test using history list.
TEST_CASE("RotatedSPOs read and write parameters history", "[wavefunction]")
{
Expand Down

0 comments on commit f466ca3

Please sign in to comment.