diff --git a/core/src/View/Kokkos_ViewTraits.hpp b/core/src/View/Kokkos_ViewTraits.hpp index 12136e8eb12..6d8cd51d51f 100644 --- a/core/src/View/Kokkos_ViewTraits.hpp +++ b/core/src/View/Kokkos_ViewTraits.hpp @@ -120,16 +120,34 @@ template struct AccessorFromViewTraits { using type = SpaceAwareAccessor>; + default_accessor>; }; template -struct AccessorFromViewTraits> { +struct AccessorFromViewTraits< + Traits, + std::enable_if_t> { using type = checked_reference_counted_accessor; }; +template +struct AccessorFromViewTraits< + Traits, + std::enable_if_t> { + using type = checked_reference_counted_atomic_accessor_relaxed< + typename Traits::value_type, typename Traits::memory_space>; +}; + +template +struct AccessorFromViewTraits< + Traits, + std::enable_if_t> { + using type = checked_atomic_accessor_relaxed; +}; + template using accessor_from_view_traits_t = typename AccessorFromViewTraits::type; diff --git a/core/src/View/MDSpan/Kokkos_MDSpan_Accessor.hpp b/core/src/View/MDSpan/Kokkos_MDSpan_Accessor.hpp index e857fb0687c..5ed44b52a79 100644 --- a/core/src/View/MDSpan/Kokkos_MDSpan_Accessor.hpp +++ b/core/src/View/MDSpan/Kokkos_MDSpan_Accessor.hpp @@ -234,6 +234,13 @@ class ReferenceCountedDataHandle { ReferenceCountedDataHandle(OtherElementType* ptr) : m_tracker(), m_handle(ptr) {} + template >> + ReferenceCountedDataHandle( + const ReferenceCountedDataHandle &other, OtherElementType* ptr) + : m_tracker(other.m_tracker), m_handle(ptr) {} + template >> @@ -247,12 +254,6 @@ class ReferenceCountedDataHandle { default; ReferenceCountedDataHandle& operator=(ReferenceCountedDataHandle&&) = default; - ReferenceCountedDataHandle with_offset(size_t offset) const { - auto ret = *this; - ret.m_handle += offset; - return ret; - } - pointer get() const noexcept { return m_handle; } explicit operator pointer() const noexcept { return m_handle; } @@ -265,6 +266,10 @@ class ReferenceCountedDataHandle { private: template friend class ReferenceCountedDataHandle; + + template + friend class ReferenceCountedAccessor; + SharedAllocationTracker m_tracker; pointer m_handle = nullptr; }; @@ -289,6 +294,12 @@ class ReferenceCountedDataHandle { ReferenceCountedDataHandle(OtherElementType* ptr) : m_tracker(), m_handle(ptr) {} + template >> + ReferenceCountedDataHandle(const ReferenceCountedDataHandle &other, OtherElementType* ptr) + : m_tracker(other.m_tracker), m_handle(ptr) {} + ReferenceCountedDataHandle(const ReferenceCountedDataHandle&) = default; ReferenceCountedDataHandle(ReferenceCountedDataHandle&&) noexcept = default; ReferenceCountedDataHandle& operator=(const ReferenceCountedDataHandle&) = @@ -302,12 +313,6 @@ class ReferenceCountedDataHandle { const ReferenceCountedDataHandle& other) : m_tracker(other.m_tracker), m_handle(other.m_handle) {} - ReferenceCountedDataHandle with_offset(size_t offset) const { - auto ret = *this; - ret.m_handle += offset; - return ret; - } - pointer get() const noexcept { return m_handle; } explicit operator pointer() const noexcept { return m_handle; } @@ -325,59 +330,75 @@ class ReferenceCountedDataHandle { pointer m_handle = nullptr; }; -template +template class ReferenceCountedAccessor { public: using element_type = ElementType; using data_handle_type = ReferenceCountedDataHandle; - using reference = typename data_handle_type::reference; + using reference = typename NestedAccessor::reference; using offset_policy = ReferenceCountedAccessor; constexpr ReferenceCountedAccessor() noexcept = default; - template >> + template < + class OtherElementType, class OtherNestedAccessor, + class = std::enable_if_t< + std::is_convertible_v && + std::is_constructible_v>> constexpr ReferenceCountedAccessor( - const ReferenceCountedAccessor&) {} - + const ReferenceCountedAccessor&) {} + template >> constexpr ReferenceCountedAccessor( const default_accessor&) {} - operator default_accessor() const { return {}; } + operator NestedAccessor() const { return m_nested_acc; } constexpr reference access(data_handle_type p, size_t i) const { - return p.get()[i]; + return m_nested_acc.access(p.get(), i); } constexpr data_handle_type offset(data_handle_type p, size_t i) const { - return p.with_offset(i); + return data_handle_type(p, m_nested_acc.offset(p.get(), i)); } + + private: +#ifdef _MDSPAN_NO_UNIQUE_ADDRESS + _MDSPAN_NO_UNIQUE_ADDRESS +#else + [[no_unique_address]] +#endif + NestedAccessor m_nested_acc; }; -template -class ReferenceCountedAccessor { +template +class ReferenceCountedAccessor { public: using element_type = ElementType; using data_handle_type = ReferenceCountedDataHandle; - using reference = typename data_handle_type::reference; + using reference = typename NestedAccessor::reference; using offset_policy = ReferenceCountedAccessor; constexpr ReferenceCountedAccessor() noexcept = default; - template + template >> constexpr ReferenceCountedAccessor( - const ReferenceCountedAccessor&) {} + const ReferenceCountedAccessor&) {} - template >> + template < + class OtherElementType, class OtherSpace, class OtherNestedAccessor, + class = std::enable_if_t< + std::is_convertible_v && + std::is_constructible_v>> constexpr ReferenceCountedAccessor( - const ReferenceCountedAccessor&) {} + const ReferenceCountedAccessor&) {} template { constexpr ReferenceCountedAccessor( const default_accessor&) {} - operator default_accessor() const { return {}; } + operator NestedAccessor() const { return m_nested_acc; } constexpr reference access(data_handle_type p, size_t i) const { - return p.get()[i]; + return m_nested_acc.access(p.get(), i); } constexpr data_handle_type offset(data_handle_type p, size_t i) const { - return p.with_offset(i); + return data_handle_type(p, m_nested_acc.offset(p.get(), i)); } + + private: +#ifdef _MDSPAN_NO_UNIQUE_ADDRESS + _MDSPAN_NO_UNIQUE_ADDRESS +#else + [[no_unique_address]] +#endif + NestedAccessor m_nested_acc; }; template using checked_reference_counted_accessor = SpaceAwareAccessor>; + ReferenceCountedAccessor>>; + +template +using checked_atomic_accessor_relaxed = SpaceAwareAccessor< + MemorySpace, AtomicAccessorRelaxed>; + +template +using checked_reference_counted_atomic_accessor_relaxed = SpaceAwareAccessor< + MemorySpace, ReferenceCountedAccessor>>; } // namespace Impl } // namespace Kokkos diff --git a/core/unit_test/view/TestBasicViewMDSpanConversion.cpp b/core/unit_test/view/TestBasicViewMDSpanConversion.cpp index 4866297ff51..75392ff7772 100644 --- a/core/unit_test/view/TestBasicViewMDSpanConversion.cpp +++ b/core/unit_test/view/TestBasicViewMDSpanConversion.cpp @@ -42,4 +42,10 @@ static_assert( Kokkos::Experimental::layout_right_padded<>, Kokkos::Impl::checked_reference_counted_accessor< const long long, Kokkos::HostSpace>>>); + +using test_atomic_view = Kokkos::View>; +static_assert(std::is_same_v< + decltype(std::declval()(std::declval())), + desul::AtomicRef>); #endif diff --git a/tpls/desul/include/desul/atomics/Atomic_Ref.hpp b/tpls/desul/include/desul/atomics/Atomic_Ref.hpp index 145c2457a2a..679c054da74 100644 --- a/tpls/desul/include/desul/atomics/Atomic_Ref.hpp +++ b/tpls/desul/include/desul/atomics/Atomic_Ref.hpp @@ -79,6 +79,10 @@ class AtomicRef { DESUL_IMPL_DEFINE_ATOMIC_FETCH_OP(xor) DESUL_IMPL_DEFINE_ATOMIC_COMPOUND_ASSIGNMENT_OP(^=, xor) DESUL_IMPL_DEFINE_ATOMIC_FETCH_OP(nand) + DESUL_IMPL_DEFINE_ATOMIC_FETCH_OP(lshift) + DESUL_IMPL_DEFINE_ATOMIC_COMPOUND_ASSIGNMENT_OP(<<=, lshift) + DESUL_IMPL_DEFINE_ATOMIC_FETCH_OP(rshift) + DESUL_IMPL_DEFINE_ATOMIC_COMPOUND_ASSIGNMENT_OP(>>=, rshift) #undef DESUL_IMPL_DEFINE_ATOMIC_COMPOUND_ASSIGNMENT_OP #undef DESUL_IMPL_DEFINE_ATOMIC_FETCH_OP