Skip to content

Commit

Permalink
Bugfix: removed QMC_COMPLEX conditions where no longer needed
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilipFackler committed Sep 27, 2023
1 parent b91b2d6 commit a0deb00
Show file tree
Hide file tree
Showing 9 changed files with 138 additions and 120 deletions.
5 changes: 3 additions & 2 deletions src/Particle/ParticleSetT.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
#ifndef QMCPLUSPLUS_PARTICLESETT_H
#define QMCPLUSPLUS_PARTICLESETT_H

#include <memory>

#include "DTModes.h"
#include "DynamicCoordinatesT.h"
#include "MCCoordsT.hpp"
Expand All @@ -38,6 +36,8 @@
#include "Walker.h"
#include "type_traits/template_types.hpp"

#include <memory>

namespace qmcplusplus
{
/// forward declarations
Expand Down Expand Up @@ -696,6 +696,7 @@ class ParticleSetT : public OhmmsElementBase
{
myTwist = t;
}

inline const SingleParticlePos&
getTwist() const
{
Expand Down
31 changes: 21 additions & 10 deletions src/QMCWaveFunctions/BsplineFactory/SplineR2RT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "SplineR2RT.h"

#include "CPU/BLAS.hpp"
#include "Concurrency/OpenMP.h"
#include "QMCWaveFunctions/BsplineFactory/contraction_helper.hpp"
#include "spline2/MultiBsplineEval.hpp"
Expand Down Expand Up @@ -125,17 +126,27 @@ SplineR2RT<ST, VT>::applyRotation(
std::copy_n(spl_coefs, coefs_tot_size, coef_copy_->begin());
}

// Apply rotation the dumb way b/c I can't get BLAS::gemm to work...
for (auto i = 0; i < BasisSetSize; i++) {
for (auto j = 0; j < this->OrbitalSetSize; j++) {
const auto cur_elem = Nsplines * i + j;
auto newval{0.};
for (auto k = 0; k < this->OrbitalSetSize; k++) {
const auto index = i * Nsplines + k;
newval += (*coef_copy_)[index] * rot_mat[k][j];
if constexpr (std::is_same_v<ST, RealType>) {
// Here, ST should be equal to ValueType, which will be double for R2R.
// Using BLAS to make things faster
BLAS::gemm('N', 'N', this->OrbitalSetSize, BasisSetSize,
this->OrbitalSetSize, ST(1.0), rot_mat.data(), this->OrbitalSetSize,
coef_copy_->data(), Nsplines, ST(0.0), spl_coefs, Nsplines);
}
else {
// Here, ST is float but ValueType is double for R2R. Due to issues with
// type conversions, just doing naive matrix multiplication in this case
// to not lose precision on rot_mat
for (IndexType i = 0; i < BasisSetSize; i++)
for (IndexType j = 0; j < this->OrbitalSetSize; j++) {
const auto cur_elem = Nsplines * i + j;
FullPrecValueType newval{0.};
for (IndexType k = 0; k < this->OrbitalSetSize; k++) {
const auto index = i * Nsplines + k;
newval += (*coef_copy_)[index] * rot_mat[k][j];
}
spl_coefs[cur_elem] = newval;
}
spl_coefs[cur_elem] = newval;
}
}
}

Expand Down
6 changes: 4 additions & 2 deletions src/QMCWaveFunctions/BsplineFactory/SplineR2RT.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,12 @@ class SplineR2RT : public BsplineSetT<VT>
using SplineType = typename bspline_traits<ST, 3>::SplineType;
using BCType = typename bspline_traits<ST, 3>::BCType;
using DataType = ST;
using RealType = typename SPOSetT<VT>::RealType;
using IndexType = typename SPOSetT<VT>::IndexType;
using FullPrecValueType = double;
using PointType = TinyVector<ST, 3>;
using SingleSplineType = UBspline_3d_d;

// types for evaluation results
using TT = typename BsplineSetT<VT>::ValueType;
using GGGVector = typename BsplineSetT<VT>::GGGVector;
Expand All @@ -55,8 +59,6 @@ class SplineR2RT : public BsplineSetT<VT>
using hContainer_type = VectorSoaContainer<ST, 6>;
using ghContainer_type = VectorSoaContainer<ST, 10>;

using RealType = typename SPOSetT<VT>::RealType;

private:
bool IsGamma;
///\f$GGt=G^t G \f$, transformation for tensor in LatticeUnit to
Expand Down
12 changes: 9 additions & 3 deletions src/QMCWaveFunctions/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,17 @@ if(OHMMS_DIM MATCHES 3)
endif(HAVE_EINSPLINE)

# plane wave SPO
set(FERMION_SRCS ${FERMION_SRCS} PlaneWave/PWBasis.cpp PlaneWave/PWBasisT.cpp PlaneWave/PWParameterSet.cpp PlaneWave/PWOrbitalBuilder.cpp)
set(FERMION_SRCS ${FERMION_SRCS}
PlaneWave/PWBasis.cpp
PlaneWave/PWBasisT.cpp
PlaneWave/PWOrbitalSetT.cpp
PlaneWave/PWRealOrbitalSetT.cpp
PlaneWave/PWParameterSet.cpp
PlaneWave/PWOrbitalBuilder.cpp)
if(QMC_COMPLEX)
set(FERMION_SRCS ${FERMION_SRCS} PlaneWave/PWOrbitalSet.cpp PlaneWave/PWOrbitalSetT.cpp)
set(FERMION_SRCS ${FERMION_SRCS} PlaneWave/PWOrbitalSet.cpp)
else()
set(FERMION_SRCS ${FERMION_SRCS} PlaneWave/PWRealOrbitalSet.cpp PlaneWave/PWRealOrbitalSetT.cpp)
set(FERMION_SRCS ${FERMION_SRCS} PlaneWave/PWRealOrbitalSet.cpp)
endif(QMC_COMPLEX)

if(NOT QMC_COMPLEX)
Expand Down
156 changes: 79 additions & 77 deletions src/QMCWaveFunctions/EinsplineSetBuilderT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -514,19 +514,19 @@ EinsplineSetBuilderT<T>::AnalyzeTwists2(
}

TargetPtcl.setTwist(superFracs[twist_num_]);
#ifndef QMC_COMPLEX
// Check to see if supercell twist is okay to use with real wave
// functions
for (int dim = 0; dim < OHMMS_DIM; dim++) {
double t = 2.0 * superFracs[twist_num_][dim];
if (std::abs(t - round(t)) > MatchingTol * 100) {
app_error()
<< "Cannot use this super twist with real wavefunctions.\n"
<< "Please recompile with QMC_COMPLEX=1.\n";
APP_ABORT("EinsplineSetBuilder::AnalyzeTwists2");
if constexpr (!IsComplex_t<T>{}()) {
// Check to see if supercell twist is okay to use with real wave
// functions
for (int dim = 0; dim < OHMMS_DIM; dim++) {
double t = 2.0 * superFracs[twist_num_][dim];
if (std::abs(t - round(t)) > MatchingTol * 100) {
app_error()
<< "Cannot use this super twist with real wavefunctions.\n"
<< "Please recompile with QMC_COMPLEX=1.\n";
APP_ABORT("EinsplineSetBuilder::AnalyzeTwists2");
}
}
}
#endif
// Now check to see that each supercell twist has the right twists
// to tile the primitive cell orbitals.
const int numTwistsNeeded = std::abs(det(TileMatrix));
Expand Down Expand Up @@ -574,78 +574,80 @@ EinsplineSetBuilderT<T>::AnalyzeTwists2(
IncludeTwists.push_back(superSets[twist_num_][i]);
// Now, find out which twists are distinct
DistinctTwists.clear();
#ifndef QMC_COMPLEX
std::vector<int> copyTwists;
for (int i = 0; i < IncludeTwists.size(); i++) {
int ti = IncludeTwists[i];
PosType twist_i = primcell_kpoints[ti];
bool distinct = true;
for (int j = i + 1; j < IncludeTwists.size(); j++) {
int tj = IncludeTwists[j];
PosType twist_j = primcell_kpoints[tj];
PosType sum = twist_i + twist_j;
PosType diff = twist_i - twist_j;
if (TwistPair(twist_i, twist_j))
distinct = false;
if constexpr (!IsComplex_t<T>{}()) {
std::vector<int> copyTwists;
for (int i = 0; i < IncludeTwists.size(); i++) {
int ti = IncludeTwists[i];
PosType twist_i = primcell_kpoints[ti];
bool distinct = true;
for (int j = i + 1; j < IncludeTwists.size(); j++) {
int tj = IncludeTwists[j];
PosType twist_j = primcell_kpoints[tj];
PosType sum = twist_i + twist_j;
PosType diff = twist_i - twist_j;
if (TwistPair(twist_i, twist_j))
distinct = false;
}
if (distinct)
DistinctTwists.push_back(ti);
else
copyTwists.push_back(ti);
}
if (distinct)
DistinctTwists.push_back(ti);
else
copyTwists.push_back(ti);
}
// Now determine which distinct twists require two copies
MakeTwoCopies.resize(DistinctTwists.size());
for (int i = 0; i < DistinctTwists.size(); i++) {
MakeTwoCopies[i] = false;
int ti = DistinctTwists[i];
PosType twist_i = primcell_kpoints[ti];
for (int j = 0; j < copyTwists.size(); j++) {
int tj = copyTwists[j];
PosType twist_j = primcell_kpoints[tj];
if (TwistPair(twist_i, twist_j))
MakeTwoCopies[i] = true;
// Now determine which distinct twists require two copies
MakeTwoCopies.resize(DistinctTwists.size());
for (int i = 0; i < DistinctTwists.size(); i++) {
MakeTwoCopies[i] = false;
int ti = DistinctTwists[i];
PosType twist_i = primcell_kpoints[ti];
for (int j = 0; j < copyTwists.size(); j++) {
int tj = copyTwists[j];
PosType twist_j = primcell_kpoints[tj];
if (TwistPair(twist_i, twist_j))
MakeTwoCopies[i] = true;
}
if (this->myComm->rank() == 0) {
std::array<char, 1000> buf;
int length = std::snprintf(buf.data(), buf.size(),
"Using %d copies of twist angle [%6.3f, %6.3f, %6.3f]\n",
MakeTwoCopies[i] ? 2 : 1, twist_i[0], twist_i[1],
twist_i[2]);
if (length < 0)
throw std::runtime_error("Error generating string");
app_log() << std::string_view(buf.data(), length);
app_log().flush();
}
}
if (this->myComm->rank() == 0) {
std::array<char, 1000> buf;
int length = std::snprintf(buf.data(), buf.size(),
"Using %d copies of twist angle [%6.3f, %6.3f, %6.3f]\n",
MakeTwoCopies[i] ? 2 : 1, twist_i[0], twist_i[1], twist_i[2]);
if (length < 0)
throw std::runtime_error("Error generating string");
app_log() << std::string_view(buf.data(), length);
app_log().flush();
// Find out if we can make real orbitals
use_real_splines_ = true;
for (int i = 0; i < DistinctTwists.size(); i++) {
int ti = DistinctTwists[i];
PosType twist = primcell_kpoints[ti];
for (int j = 0; j < OHMMS_DIM; j++)
if (std::abs(twist[j] - 0.0) > MatchingTol &&
std::abs(twist[j] - 0.5) > MatchingTol &&
std::abs(twist[j] + 0.5) > MatchingTol)
use_real_splines_ = false;
}
if (use_real_splines_ && (DistinctTwists.size() > 1)) {
app_log() << "***** Use of real orbitals is possible, but not "
"currently implemented\n"
<< " with more than one twist angle.\n";
use_real_splines_ = false;
}
if (use_real_splines_)
app_log() << "Using real splines.\n";
else
app_log() << "Using complex splines.\n";
}
// Find out if we can make real orbitals
use_real_splines_ = true;
for (int i = 0; i < DistinctTwists.size(); i++) {
int ti = DistinctTwists[i];
PosType twist = primcell_kpoints[ti];
for (int j = 0; j < OHMMS_DIM; j++)
if (std::abs(twist[j] - 0.0) > MatchingTol &&
std::abs(twist[j] - 0.5) > MatchingTol &&
std::abs(twist[j] + 0.5) > MatchingTol)
use_real_splines_ = false;
}
if (use_real_splines_ && (DistinctTwists.size() > 1)) {
app_log() << "***** Use of real orbitals is possible, but not "
"currently implemented\n"
<< " with more than one twist angle.\n";
else {
DistinctTwists.resize(IncludeTwists.size());
MakeTwoCopies.resize(IncludeTwists.size());
for (int i = 0; i < IncludeTwists.size(); i++) {
DistinctTwists[i] = IncludeTwists[i];
MakeTwoCopies[i] = false;
}
use_real_splines_ = false;
}
if (use_real_splines_)
app_log() << "Using real splines.\n";
else
app_log() << "Using complex splines.\n";
#else
DistinctTwists.resize(IncludeTwists.size());
MakeTwoCopies.resize(IncludeTwists.size());
for (int i = 0; i < IncludeTwists.size(); i++) {
DistinctTwists[i] = IncludeTwists[i];
MakeTwoCopies[i] = false;
}
use_real_splines_ = false;
#endif
}

template <typename T>
Expand Down
4 changes: 2 additions & 2 deletions src/QMCWaveFunctions/PlaneWave/PWOrbitalSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
/** @file PWOrbitalSet.h
* @brief Definition of member functions of Plane-wave basis set
*/
#ifndef QMCPLUSPLUS_PLANEWAVE_ORBITALSETT_BLAS_H
#define QMCPLUSPLUS_PLANEWAVE_ORBITALSETT_BLAS_H
#ifndef QMCPLUSPLUS_PLANEWAVE_ORBITALSET_BLAS_H
#define QMCPLUSPLUS_PLANEWAVE_ORBITALSET_BLAS_H

#include "QMCWaveFunctions/PlaneWave/PWBasis.h"
#include "QMCWaveFunctions/SPOSet.h"
Expand Down
4 changes: 2 additions & 2 deletions src/QMCWaveFunctions/PlaneWave/PWOrbitalSetT.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
/** @file PWOrbitalSetT.h
* @brief Definition of member functions of Plane-wave basis set
*/
#ifndef QMCPLUSPLUS_PLANEWAVE_ORBITALSET_BLAS_H
#define QMCPLUSPLUS_PLANEWAVE_ORBITALSET_BLAS_H
#ifndef QMCPLUSPLUS_PLANEWAVE_ORBITALSETT_BLAS_H
#define QMCPLUSPLUS_PLANEWAVE_ORBITALSETT_BLAS_H

#include "CPU/BLAS.hpp"
#include "QMCWaveFunctions/PlaneWave/PWBasisT.h"
Expand Down
Loading

0 comments on commit a0deb00

Please sign in to comment.