diff --git a/src/Numerics/OhmmsPETE/OhmmsMatrix.h b/src/Numerics/OhmmsPETE/OhmmsMatrix.h index 4aa963d1b..0d74c1308 100644 --- a/src/Numerics/OhmmsPETE/OhmmsMatrix.h +++ b/src/Numerics/OhmmsPETE/OhmmsMatrix.h @@ -22,7 +22,7 @@ namespace qmcplusplus { -template, unsigned MemType = MemorySpace::HOST> +template> class Matrix { public: @@ -30,10 +30,10 @@ class Matrix typedef T value_type; typedef T* pointer; typedef const T* const_pointer; - typedef Vector Container_t; + typedef Vector Container_t; typedef typename Container_t::size_type size_type; typedef typename Container_t::iterator iterator; - typedef Matrix This_t; + typedef Matrix This_t; Matrix() : D1(0), D2(0), TotSize(0) {} // Default Constructor initializes to zero. @@ -56,7 +56,7 @@ class Matrix Matrix(const This_t& rhs) { resize(rhs.D1, rhs.D2); - if (MemType == MemorySpace::HOST) + if (allocator_traits::is_host_accessible) assign(*this, rhs); } @@ -107,16 +107,16 @@ class Matrix X.attachReference(ref, TotSize); } + template> inline void add(size_type n) // you can add rows: adding columns are forbidden { - static_assert(MemType == MemorySpace::HOST, "Matrix::add MemType must be MemorySpace::HOST"); X.insert(X.end(), n * D2, T()); D1 += n; } + template> inline void copy(const This_t& rhs) { - static_assert(MemType == MemorySpace::HOST, "Matrix::copy MemType must be MemorySpace::HOST"); resize(rhs.D1, rhs.D2); assign(*this, rhs); } @@ -124,21 +124,15 @@ class Matrix // Assignment Operators inline This_t& operator=(const This_t& rhs) { - static_assert(MemType == MemorySpace::HOST, "Matrix::operator= MemType must be MemorySpace::HOST"); resize(rhs.D1, rhs.D2); - return assign(*this, rhs); - } - - inline const This_t& operator=(const This_t& rhs) const - { - static_assert(MemType == MemorySpace::HOST, "Matrix::operator= MemType must be MemorySpace::HOST"); - return assign(*this, rhs); + if (allocator_traits::is_host_accessible) + assign(*this, rhs); + return *this; } - template + template> This_t& operator=(const RHS& rhs) { - static_assert(MemType == MemorySpace::HOST, "Matrix::operator= MemType must be MemorySpace::HOST"); return assign(*this, rhs); } @@ -172,35 +166,36 @@ class Matrix /// returns a pointer of i-th row, g++ iterator problem inline Type_t* operator[](size_type i) { return X.data() + i * D2; } + template> inline Type_t& operator()(size_type i) { - static_assert(MemType == MemorySpace::HOST, "Matrix::operator() MemType must be MemorySpace::HOST"); return X[i]; } + // returns the i-th value in D1*D2 vector + template> inline Type_t operator()(size_type i) const { - static_assert(MemType == MemorySpace::HOST, "Matrix::operator() MemType must be MemorySpace::HOST"); return X[i]; } // returns val(i,j) + template> inline Type_t& operator()(size_type i, size_type j) { - static_assert(MemType == MemorySpace::HOST, "Matrix::operator() MemType must be MemorySpace::HOST"); return X[i * D2 + j]; } // returns val(i,j) + template> inline const Type_t& operator()(size_type i, size_type j) const { - static_assert(MemType == MemorySpace::HOST, "Matrix::operator() MemType must be MemorySpace::HOST"); return X[i * D2 + j]; } + template> inline void swap_rows(int r1, int r2) { - static_assert(MemType == MemorySpace::HOST, "Matrix::operator() MemType must be MemorySpace::HOST"); for (int col = 0; col < D2; col++) { Type_t tmp = (*this)(r1, col); @@ -209,9 +204,9 @@ class Matrix } } + template> inline void swap_cols(int c1, int c2) { - static_assert(MemType == MemorySpace::HOST, "Matrix::operator() MemType must be MemorySpace::HOST"); for (int row = 0; row < D1; row++) { Type_t tmp = (*this)(row, c1); @@ -220,27 +215,23 @@ class Matrix } } - - template + template> inline void replaceRow(IT first, size_type i) { - static_assert(MemType == MemorySpace::HOST, "Matrix::operator() MemType must be MemorySpace::HOST"); std::copy(first, first + D2, X.begin() + i * D2); } - template + template> inline void replaceColumn(IT first, size_type j) { - static_assert(MemType == MemorySpace::HOST, "Matrix::operator() MemType must be MemorySpace::HOST"); typename Container_t::iterator ii(X.begin() + j); for (int i = 0; i < D1; i++, ii += D2) *ii = *first++; } - template + template> inline void add2Column(IT first, size_type j) { - static_assert(MemType == MemorySpace::HOST, "Matrix::operator() MemType must be MemorySpace::HOST"); typename Container_t::iterator ii(X.begin() + j); for (int i = 0; i < D1; i++, ii += D2) *ii += *first++; @@ -253,10 +244,9 @@ class Matrix * \param i0 starting row where the copying is done * \param j0 starting column where the copying is done */ - template + template> inline void add(const T1* sub, size_type d1, size_type d2, size_type i0, size_type j0) { - static_assert(MemType == MemorySpace::HOST, "Matrix::operator() MemType must be MemorySpace::HOST"); int ii = 0; for (int i = 0; i < d1; i++) { @@ -268,10 +258,9 @@ class Matrix } } - template + template> inline void add(const T1* sub, size_type d1, size_type d2, size_type i0, size_type j0, const T& phi) { - static_assert(MemType == MemorySpace::HOST, "Matrix::operator() MemType must be MemorySpace::HOST"); size_type ii = 0; for (size_type i = 0; i < d1; i++) { @@ -283,10 +272,9 @@ class Matrix } } - template + template> inline void add(const SubMat_t& sub, unsigned int i0, unsigned int j0) { - static_assert(MemType == MemorySpace::HOST, "Matrix::operator() MemType must be MemorySpace::HOST"); size_type ii = 0; for (size_type i = 0; i < sub.rows(); i++) { @@ -298,9 +286,9 @@ class Matrix } } + template> inline void add(const This_t& sub, unsigned int i0, unsigned int j0) { - static_assert(MemType == MemorySpace::HOST, "Matrix::operator() MemType must be MemorySpace::HOST"); size_type ii = 0; for (size_type i = 0; i < sub.rows(); i++) { @@ -312,18 +300,16 @@ class Matrix } } - template + template> inline Msg& putMessage(Msg& m) { - static_assert(MemType == MemorySpace::HOST, "Matrix::operator() MemType must be MemorySpace::HOST"); m.Pack(X.data(), D1 * D2); return m; } - template + template> inline Msg& getMessage(Msg& m) { - static_assert(MemType == MemorySpace::HOST, "Matrix::operator() MemType must be MemorySpace::HOST"); m.Unpack(X.data(), D1 * D2); return m; } diff --git a/src/Numerics/OhmmsPETE/OhmmsVector.h b/src/Numerics/OhmmsPETE/OhmmsVector.h index 8972344f4..cee6edd2b 100644 --- a/src/Numerics/OhmmsPETE/OhmmsVector.h +++ b/src/Numerics/OhmmsPETE/OhmmsVector.h @@ -25,11 +25,11 @@ #include #include #include "Numerics/PETE/PETE.h" -#include "Utilities/SIMD/MemorySpace.hpp" +#include "Utilities/SIMD/allocator_traits.hpp" namespace qmcplusplus { -template, unsigned MemType = MemorySpace::HOST> +template> class Vector { public: @@ -48,7 +48,7 @@ class Vector if (n) { resize_impl(n); - if (MemType == MemorySpace::HOST) + if (allocator_traits::is_host_accessible) std::fill_n(X, n, val); } } @@ -60,28 +60,26 @@ class Vector Vector(const Vector& rhs) : nLocal(rhs.nLocal), nAllocated(0), X(nullptr) { resize_impl(rhs.nLocal); - if (MemType == MemorySpace::HOST) + if (allocator_traits::is_host_accessible) std::copy_n(rhs.data(), nLocal, X); } // default assignment operator inline Vector& operator=(const Vector& rhs) { - static_assert(MemType == MemorySpace::HOST, "Vector::operator= MemType must be MemorySpace::HOST"); if (this == &rhs) return *this; if (nLocal != rhs.nLocal) resize(rhs.nLocal); - std::copy_n(rhs.data(), nLocal, X); + if (allocator_traits::is_host_accessible) + std::copy_n(rhs.data(), nLocal, X); return *this; } // assignment operator from anther Vector class - template + template> inline Vector& operator=(const Vector& rhs) { - static_assert(MemType == MemorySpace::HOST, - "Vector::operator= the MemType of both sides must be MemorySpace::HOST"); if (std::is_convertible::value) { if (nLocal != rhs.nLocal) @@ -92,10 +90,9 @@ class Vector } // assigment operator to enable PETE - template + template> inline Vector& operator=(const RHS& rhs) { - static_assert(MemType == MemorySpace::HOST, "Vector::operator= MemType must be MemorySpace::HOST"); assign(*this, rhs); return *this; } @@ -128,20 +125,26 @@ class Vector static_assert(std::is_same::value, "Vector and Alloc data types must agree!"); if (nLocal > nAllocated) throw std::runtime_error("Resize not allowed on Vector constructed by initialized memory."); - if (n > nAllocated) + if(allocator_traits::is_host_accessible) { - resize_impl(n); - if (MemType == MemorySpace::HOST) + if (n > nAllocated) + { + resize_impl(n); std::fill_n(X, n, val); + } + else + { + if (n > nLocal) std::fill_n(X + nLocal, n - nLocal, val); + nLocal = n; + } } - else if (n > nLocal) + else { - if (MemType == MemorySpace::HOST) - std::fill_n(X + nLocal, n - nLocal, val); - nLocal = n; + if (n > nAllocated) + resize_impl(n); + else + nLocal = n; } - else - nLocal = n; return; } @@ -161,15 +164,15 @@ class Vector } // Get and Set Operations + template> inline Type_t& operator[](size_t i) { - static_assert(MemType == MemorySpace::HOST, "Vector::operator[] MemType must be MemorySpace::HOST"); return X[i]; } + template> inline const Type_t& operator[](size_t i) const { - static_assert(MemType == MemorySpace::HOST, "Vector::operator[] MemType must be MemorySpace::HOST"); return X[i]; } diff --git a/src/Utilities/SIMD/MemorySpace.hpp b/src/Utilities/SIMD/allocator_traits.hpp similarity index 50% rename from src/Utilities/SIMD/MemorySpace.hpp rename to src/Utilities/SIMD/allocator_traits.hpp index 80d009b1d..2bdc594a4 100644 --- a/src/Utilities/SIMD/MemorySpace.hpp +++ b/src/Utilities/SIMD/allocator_traits.hpp @@ -8,22 +8,27 @@ // // File created by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory ////////////////////////////////////////////////////////////////////////////////////// -// -*- C++ -*- -/** @file MemorySpace.hpp - */ -#ifndef QMCPLUSPLUS_MEMORYSPACE_H -#define QMCPLUSPLUS_MEMORYSPACE_H + + +#ifndef QMCPLUSPLUS_ACCESS_TRAITS_H +#define QMCPLUSPLUS_ACCESS_TRAITS_H namespace qmcplusplus { -struct MemorySpace +/** template class defines whether the memory allocated by the allocator is host accessible + */ +template +struct allocator_traits { - enum - { - HOST = 0, - CUDA - }; + const static bool is_host_accessible = true; }; + +template +using IsHostSafe = typename std::enable_if::is_host_accessible>::type; + +template +using IsNotHostSafe = typename std::enable_if::is_host_accessible>::type; + } // namespace qmcplusplus -#endif +#endif // QMCPLUSPLUS_ACCESS_TRAITS_H