Skip to content

Commit

Permalink
more fixes to accessors and add changes from desul/desul#129
Browse files Browse the repository at this point in the history
  • Loading branch information
nmm0 committed Jul 11, 2024
1 parent 3462bd7 commit ac3f520
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 37 deletions.
22 changes: 20 additions & 2 deletions core/src/View/Kokkos_ViewTraits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,34 @@ template <class Traits, class Enabled = void>
struct AccessorFromViewTraits {
using type =
SpaceAwareAccessor<typename Traits::memory_space,
Kokkos::default_accessor<typename Traits::value_type>>;
default_accessor<typename Traits::value_type>>;
};

template <class Traits>
struct AccessorFromViewTraits<Traits, std::enable_if_t<Traits::is_managed>> {
struct AccessorFromViewTraits<
Traits,
std::enable_if_t<Traits::is_managed && !Traits::memory_traits::is_atomic>> {
using type =
checked_reference_counted_accessor<typename Traits::value_type,
typename Traits::memory_space>;
};

template <class Traits>
struct AccessorFromViewTraits<
Traits,
std::enable_if_t<Traits::is_managed && Traits::memory_traits::is_atomic>> {
using type = checked_reference_counted_atomic_accessor_relaxed<
typename Traits::value_type, typename Traits::memory_space>;
};

template <class Traits>
struct AccessorFromViewTraits<
Traits,
std::enable_if_t<!Traits::is_managed && Traits::memory_traits::is_atomic>> {
using type = checked_atomic_accessor_relaxed<typename Traits::value_type,
typename Traits::memory_space>;
};

template <class Traits>
using accessor_from_view_traits_t = typename AccessorFromViewTraits<Traits>::type;

Expand Down
111 changes: 76 additions & 35 deletions core/src/View/MDSpan/Kokkos_MDSpan_Accessor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,13 @@ class ReferenceCountedDataHandle {
ReferenceCountedDataHandle(OtherElementType* ptr)
: m_tracker(), m_handle(ptr) {}

template <class OtherElementType,
class = std::enable_if_t<std::is_convertible_v<
OtherElementType (*)[], value_type (*)[]>>>
ReferenceCountedDataHandle(
const ReferenceCountedDataHandle &other, OtherElementType* ptr)
: m_tracker(other.m_tracker), m_handle(ptr) {}

template <class OtherElementType,
class = std::enable_if_t<std::is_convertible_v<
OtherElementType (*)[], value_type (*)[]>>>
Expand All @@ -247,12 +254,6 @@ class ReferenceCountedDataHandle {
default;
ReferenceCountedDataHandle& operator=(ReferenceCountedDataHandle&&) = default;

ReferenceCountedDataHandle with_offset(size_t offset) const {
auto ret = *this;
ret.m_handle += offset;
return ret;
}

pointer get() const noexcept { return m_handle; }
explicit operator pointer() const noexcept { return m_handle; }

Expand All @@ -265,6 +266,10 @@ class ReferenceCountedDataHandle {
private:
template <class OtherElementType, class OtherSpace>
friend class ReferenceCountedDataHandle;

template <class OtherElementType, class OtherSpace, class NestedAccessor>
friend class ReferenceCountedAccessor;

SharedAllocationTracker m_tracker;
pointer m_handle = nullptr;
};
Expand All @@ -289,6 +294,12 @@ class ReferenceCountedDataHandle<ElementType, AnonymousSpace> {
ReferenceCountedDataHandle(OtherElementType* ptr)
: m_tracker(), m_handle(ptr) {}

template <class OtherElementType,
class = std::enable_if_t<std::is_convertible_v<
OtherElementType (*)[], value_type (*)[]>>>
ReferenceCountedDataHandle(const ReferenceCountedDataHandle &other, OtherElementType* ptr)
: m_tracker(other.m_tracker), m_handle(ptr) {}

ReferenceCountedDataHandle(const ReferenceCountedDataHandle&) = default;
ReferenceCountedDataHandle(ReferenceCountedDataHandle&&) noexcept = default;
ReferenceCountedDataHandle& operator=(const ReferenceCountedDataHandle&) =
Expand All @@ -302,12 +313,6 @@ class ReferenceCountedDataHandle<ElementType, AnonymousSpace> {
const ReferenceCountedDataHandle<OtherElementType, OtherSpace>& other)
: m_tracker(other.m_tracker), m_handle(other.m_handle) {}

ReferenceCountedDataHandle with_offset(size_t offset) const {
auto ret = *this;
ret.m_handle += offset;
return ret;
}

pointer get() const noexcept { return m_handle; }
explicit operator pointer() const noexcept { return m_handle; }

Expand All @@ -325,81 +330,117 @@ class ReferenceCountedDataHandle<ElementType, AnonymousSpace> {
pointer m_handle = nullptr;
};

template <class ElementType, class MemorySpace>
template <class ElementType, class MemorySpace, class NestedAccessor>
class ReferenceCountedAccessor {
public:
using element_type = ElementType;
using data_handle_type = ReferenceCountedDataHandle<ElementType, MemorySpace>;
using reference = typename data_handle_type::reference;
using reference = typename NestedAccessor::reference;
using offset_policy = ReferenceCountedAccessor;

constexpr ReferenceCountedAccessor() noexcept = default;

template <class OtherElementType,
class = std::enable_if_t<std::is_convertible_v<
OtherElementType (*)[], element_type (*)[]>>>
template <
class OtherElementType, class OtherNestedAccessor,
class = std::enable_if_t<
std::is_convertible_v<OtherElementType (*)[], element_type (*)[]> &&
std::is_constructible_v<NestedAccessor, OtherNestedAccessor>>>
constexpr ReferenceCountedAccessor(
const ReferenceCountedAccessor<OtherElementType, MemorySpace>&) {}

const ReferenceCountedAccessor<OtherElementType, MemorySpace,
OtherNestedAccessor>&) {}

template <class OtherElementType,
class = std::enable_if_t<std::is_convertible_v<
OtherElementType (*)[], element_type (*)[]>>>
constexpr ReferenceCountedAccessor(
const default_accessor<OtherElementType>&) {}

operator default_accessor<element_type>() const { return {}; }
operator NestedAccessor() const { return m_nested_acc; }

constexpr reference access(data_handle_type p, size_t i) const {
return p.get()[i];
return m_nested_acc.access(p.get(), i);
}

constexpr data_handle_type offset(data_handle_type p, size_t i) const {
return p.with_offset(i);
return data_handle_type(p, m_nested_acc.offset(p.get(), i));
}

private:
#ifdef _MDSPAN_NO_UNIQUE_ADDRESS
_MDSPAN_NO_UNIQUE_ADDRESS
#else
[[no_unique_address]]
#endif
NestedAccessor m_nested_acc;
};

template <class ElementType>
class ReferenceCountedAccessor<ElementType, AnonymousSpace> {
template <class ElementType, class NestedAccessor>
class ReferenceCountedAccessor<ElementType, AnonymousSpace, NestedAccessor> {
public:
using element_type = ElementType;
using data_handle_type =
ReferenceCountedDataHandle<ElementType, AnonymousSpace>;
using reference = typename data_handle_type::reference;
using reference = typename NestedAccessor::reference;
using offset_policy = ReferenceCountedAccessor;

constexpr ReferenceCountedAccessor() noexcept = default;

template <class OtherSpace>
template <class OtherSpace, class OtherNestedAccessor,
class = std::enable_if_t<
std::is_constructible_v<NestedAccessor, OtherNestedAccessor>>>
constexpr ReferenceCountedAccessor(
const ReferenceCountedAccessor<ElementType, OtherSpace>&) {}
const ReferenceCountedAccessor<ElementType, OtherSpace,
OtherNestedAccessor>&) {}

template <class OtherElementType, class OtherSpace,
class = std::enable_if_t<std::is_convertible_v<
OtherElementType (*)[], element_type (*)[]>>>
template <
class OtherElementType, class OtherSpace, class OtherNestedAccessor,
class = std::enable_if_t<
std::is_convertible_v<OtherElementType (*)[], element_type (*)[]> &&
std::is_constructible_v<NestedAccessor, OtherNestedAccessor>>>
constexpr ReferenceCountedAccessor(
const ReferenceCountedAccessor<OtherElementType, OtherSpace>&) {}
const ReferenceCountedAccessor<OtherElementType, OtherSpace, OtherNestedAccessor>&) {}

template <class OtherElementType,
class = std::enable_if_t<std::is_convertible_v<
OtherElementType (*)[], element_type (*)[]>>>
constexpr ReferenceCountedAccessor(
const default_accessor<OtherElementType>&) {}

operator default_accessor<element_type>() const { return {}; }
operator NestedAccessor() const { return m_nested_acc; }

constexpr reference access(data_handle_type p, size_t i) const {
return p.get()[i];
return m_nested_acc.access(p.get(), i);
}

constexpr data_handle_type offset(data_handle_type p, size_t i) const {
return p.with_offset(i);
return data_handle_type(p, m_nested_acc.offset(p.get(), i));
}

private:
#ifdef _MDSPAN_NO_UNIQUE_ADDRESS
_MDSPAN_NO_UNIQUE_ADDRESS
#else
[[no_unique_address]]
#endif
NestedAccessor m_nested_acc;
};

template <class ElementType, class MemorySpace>
using checked_reference_counted_accessor =
SpaceAwareAccessor<MemorySpace,
ReferenceCountedAccessor<ElementType, MemorySpace>>;
ReferenceCountedAccessor<ElementType, MemorySpace,
default_accessor<ElementType>>>;

template <class ElementType, class MemorySpace,
class MemoryScope = desul::MemoryScopeDevice>
using checked_atomic_accessor_relaxed = SpaceAwareAccessor<
MemorySpace, AtomicAccessorRelaxed<ElementType>>;

template <class ElementType, class MemorySpace,
class MemoryScope = desul::MemoryScopeDevice>
using checked_reference_counted_atomic_accessor_relaxed = SpaceAwareAccessor<
MemorySpace, ReferenceCountedAccessor<ElementType, MemorySpace,
AtomicAccessorRelaxed<ElementType>>>;

} // namespace Impl
} // namespace Kokkos
Expand Down
6 changes: 6 additions & 0 deletions core/unit_test/view/TestBasicViewMDSpanConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,10 @@ static_assert(
Kokkos::Experimental::layout_right_padded<>,
Kokkos::Impl::checked_reference_counted_accessor<
const long long, Kokkos::HostSpace>>>);

using test_atomic_view = Kokkos::View<double *, Kokkos::Serial, Kokkos::MemoryTraits<Kokkos::Atomic>>;
static_assert(std::is_same_v<
decltype(std::declval<test_atomic_view>()(std::declval<int>())),
desul::AtomicRef<double, desul::MemoryOrderRelaxed,
desul::MemoryScopeDevice>>);
#endif
4 changes: 4 additions & 0 deletions tpls/desul/include/desul/atomics/Atomic_Ref.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ class AtomicRef {
DESUL_IMPL_DEFINE_ATOMIC_FETCH_OP(xor)
DESUL_IMPL_DEFINE_ATOMIC_COMPOUND_ASSIGNMENT_OP(^=, xor)
DESUL_IMPL_DEFINE_ATOMIC_FETCH_OP(nand)
DESUL_IMPL_DEFINE_ATOMIC_FETCH_OP(lshift)
DESUL_IMPL_DEFINE_ATOMIC_COMPOUND_ASSIGNMENT_OP(<<=, lshift)
DESUL_IMPL_DEFINE_ATOMIC_FETCH_OP(rshift)
DESUL_IMPL_DEFINE_ATOMIC_COMPOUND_ASSIGNMENT_OP(>>=, rshift)

#undef DESUL_IMPL_DEFINE_ATOMIC_COMPOUND_ASSIGNMENT_OP
#undef DESUL_IMPL_DEFINE_ATOMIC_FETCH_OP
Expand Down

0 comments on commit ac3f520

Please sign in to comment.