Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MDSpan issues expose by Kokkos View refactor #358

Merged
merged 5 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 42 additions & 2 deletions include/experimental/__p2630_bits/submdspan_extents.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include <tuple>
#include <complex>

#include "strided_slice.hpp"
namespace MDSPAN_IMPL_STANDARD_NAMESPACE {
Expand Down Expand Up @@ -52,6 +53,33 @@ template <class OffsetType, class ExtentType, class StrideType>
struct is_strided_slice<
strided_slice<OffsetType, ExtentType, StrideType>> : std::true_type {};

// Helper for identifying valid pair like things
template <class T, class IndexType> struct index_pair_like : std::false_type {};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The feature macro for complex get support is called __cpp_lib_tuple_like so maybe

Suggested change
template <class T, class IndexType> struct index_pair_like : std::false_type {};
template <class T, class IndexType> struct index_tuple_like : std::false_type {};

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this has nothing to do with get: https://eel.is/c++draft/views.multidim#mdspan.syn
We are using an exposition only concept called index-pair-like and if you follow all the stuff down to their roots, this thing here basically implements that concept in form of a type_trait.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason why I looked at get was that that's what is used in first_of which now uses this concept as requirement.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed that by adding special versions for complex.


template <class IdxT1, class IdxT2, class IndexType>
struct index_pair_like<std::pair<IdxT1, IdxT2>, IndexType> {
static constexpr bool value = std::is_convertible_v<IdxT1, IndexType> &&
std::is_convertible_v<IdxT2, IndexType>;
};

template <class IdxT1, class IdxT2, class IndexType>
struct index_pair_like<std::tuple<IdxT1, IdxT2>, IndexType> {
static constexpr bool value = std::is_convertible_v<IdxT1, IndexType> &&
std::is_convertible_v<IdxT2, IndexType>;
};

template <class IdxT, class IndexType>
struct index_pair_like<std::complex<IdxT>, IndexType> {
static constexpr bool value = std::is_convertible_v<IdxT, IndexType>;
};
dalg24 marked this conversation as resolved.
Show resolved Hide resolved

template <class IdxT, class IndexType>
struct index_pair_like<std::array<IdxT, 2>, IndexType> {
static constexpr bool value = std::is_convertible_v<IdxT, IndexType>;
};

// FIXME: we actually need to pass IndexType into all of these
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is left to do?


// first_of(slice): getting begin of slice specifier range
MDSPAN_TEMPLATE_REQUIRES(
class Integral,
Expand All @@ -70,13 +98,19 @@ first_of(const ::MDSPAN_IMPL_STANDARD_NAMESPACE::full_extent_t &) {

MDSPAN_TEMPLATE_REQUIRES(
class Slice,
/* requires */(std::is_convertible_v<Slice, std::tuple<size_t, size_t>>)
/* requires */(index_pair_like<Slice, size_t>::value)
)
MDSPAN_INLINE_FUNCTION
constexpr auto first_of(const Slice &i) {
return std::get<0>(i);
}

template<class T>
MDSPAN_INLINE_FUNCTION
constexpr auto first_of(const std::complex<T> &i) {
return i.real();
}

template <class OffsetType, class ExtentType, class StrideType>
MDSPAN_INLINE_FUNCTION
constexpr OffsetType
Expand All @@ -100,14 +134,20 @@ constexpr Integral

MDSPAN_TEMPLATE_REQUIRES(
size_t k, class Extents, class Slice,
/* requires */(std::is_convertible_v<Slice, std::tuple<size_t, size_t>>)
/* requires */(index_pair_like<Slice, size_t>::value)
)
MDSPAN_INLINE_FUNCTION
constexpr auto last_of(std::integral_constant<size_t, k>, const Extents &,
const Slice &i) {
return std::get<1>(i);
}

template<size_t k, class Extents, class T>
MDSPAN_INLINE_FUNCTION
constexpr auto last_of(std::integral_constant<size_t, k>, const Extents &, const std::complex<T> &i) {
return i.imag();
}

// Suppress spurious warning with NVCC about no return statement.
// This is a known issue in NVCC and NVC++
// Depending on the CUDA and GCC version we need both the builtin
Expand Down
3 changes: 1 addition & 2 deletions include/experimental/__p2630_bits/submdspan_mapping.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,7 @@ template<class SliceSpecifier, class IndexType>
struct is_range_slice {
constexpr static bool value =
std::is_same_v<SliceSpecifier, full_extent_t> ||
std::is_convertible_v<SliceSpecifier,
std::tuple<IndexType, IndexType>>;
index_pair_like<SliceSpecifier, IndexType>::value;
};

template<class SliceSpecifier, class IndexType>
Expand Down
9 changes: 5 additions & 4 deletions include/experimental/__p2642_bits/layout_padded.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ struct padded_extent {
using static_array_type = typename static_array_type_for_padded_extent<
padding_value, _Extents, _ExtentToPadIdx, _Extents::rank()>::type;

MDSPAN_INLINE_FUNCTION
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you actually need this or did you just add for consistency?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

static_value is called inside a couple constructors and is_always_exhaustive all of which are device callable. The "INLINE" is just for consistency, at some point we can go wholesale through this stuff and replace it all with MDSPAN_FUNCTION

static constexpr auto static_value() { return static_array_type::static_value(0); }

MDSPAN_INLINE_FUNCTION
Expand Down Expand Up @@ -203,7 +204,7 @@ class layout_left_padded<PaddingValue>::mapping {
}

public:
#if !MDSPAN_HAS_CXX_20
#if !MDSPAN_HAS_CXX_20 || defined(__NVCC__)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the problem here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NVCC does not like the requires stuff down there. Basically, it ends up saying that there is not default constructor on the device available in C++20 mode.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't able to narrow it down in a simple reproducer.

MDSPAN_INLINE_FUNCTION_DEFAULTED
constexpr mapping()
: mapping(extents_type{})
Expand Down Expand Up @@ -347,7 +348,7 @@ class layout_left_padded<PaddingValue>::mapping {
MDSPAN_INLINE_FUNCTION
constexpr mapping(const _Mapping &other_mapping) noexcept
: padded_stride(padded_stride_type::init_padding(
other_mapping.extents(),
static_cast<extents_type>(other_mapping.extents()),
other_mapping.extents().extent(extent_to_pad_idx))),
exts(other_mapping.extents()) {}

Expand Down Expand Up @@ -566,7 +567,7 @@ class layout_right_padded<PaddingValue>::mapping {
}

public:
#if !MDSPAN_HAS_CXX_20
#if !MDSPAN_HAS_CXX_20 || defined(__NVCC__)
MDSPAN_INLINE_FUNCTION_DEFAULTED
constexpr mapping()
: mapping(extents_type{})
Expand Down Expand Up @@ -707,7 +708,7 @@ class layout_right_padded<PaddingValue>::mapping {
MDSPAN_INLINE_FUNCTION
constexpr mapping(const _Mapping &other_mapping) noexcept
: padded_stride(padded_stride_type::init_padding(
other_mapping.extents(),
static_cast<extents_type>(other_mapping.extents()),
other_mapping.extents().extent(extent_to_pad_idx))),
exts(other_mapping.extents()) {}

Expand Down
4 changes: 4 additions & 0 deletions tests/test_layout_padded_left.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,8 @@ TEST(LayoutLeftTests, construction)
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>()).padded_stride.value(0)), 0);
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 4, 7>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 4, 7>>()).extents()), (Kokkos::extents<std::size_t, 4, 7>()));
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 4, 7>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 4, 7>>()).padded_stride.value(0)), 4);
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 4, 7>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, Kokkos::dynamic_extent, 7>>(Kokkos::extents<std::size_t, Kokkos::dynamic_extent, 7>(4))).extents()), (Kokkos::extents<std::size_t, 4, 7>()));
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 4, 7>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, Kokkos::dynamic_extent, 7>>(Kokkos::extents<std::size_t, Kokkos::dynamic_extent, 7>(4))).padded_stride.value(0)), 4);
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, Kokkos::dynamic_extent, 7>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 4, 7>>()).extents()), (Kokkos::extents<std::size_t, 4, 7>()));
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, Kokkos::dynamic_extent, 7>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 4, 7>>()).padded_stride.value(0)), 4);

Expand All @@ -311,6 +313,8 @@ TEST(LayoutLeftTests, construction)
ASSERT_EQ(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t>>()).extents(), Kokkos::extents<std::size_t>());
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>()).extents()), (Kokkos::extents<std::size_t, 3>()));
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>()).padded_stride.value(0)), 0);
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, Kokkos::dynamic_extent>>(Kokkos::dextents<size_t,1>(3))).extents()), (Kokkos::extents<std::size_t, 3>()));
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, Kokkos::dynamic_extent>>(Kokkos::dextents<size_t,1>(3))).padded_stride.value(0)), 0);
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_right_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 3>>({}, 4)).extents()), (Kokkos::extents<std::size_t, 3>()));
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_right_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 3>>({}, 4)).padded_stride.value(0)), 0);

Expand Down
4 changes: 4 additions & 0 deletions tests/test_layout_padded_right.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,13 +304,17 @@ TEST(LayoutrightTests, construction)
ASSERT_EQ((KokkosEx::layout_right_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>()).padded_stride.value(0)), 0);
ASSERT_EQ((KokkosEx::layout_right_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 7, 5>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 7, 5>>()).extents()), (Kokkos::extents<std::size_t, 7, 5>()));
ASSERT_EQ((KokkosEx::layout_right_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 7, 5>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 7, 5>>()).padded_stride.value(0)), 8);
ASSERT_EQ((KokkosEx::layout_right_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 7, 5>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, Kokkos::dynamic_extent, 5>>(Kokkos::extents<size_t, Kokkos::dynamic_extent, 5>(7))).extents()), (Kokkos::extents<std::size_t, 7, 5>()));
ASSERT_EQ((KokkosEx::layout_right_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 7, 5>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, Kokkos::dynamic_extent, 5>>(Kokkos::extents<size_t, Kokkos::dynamic_extent, 5>(7))).padded_stride.value(0)), 8);
ASSERT_EQ((KokkosEx::layout_right_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 7, Kokkos::dynamic_extent>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 7, 5>>()).extents()), (Kokkos::extents<std::size_t, 7, 5>()));
ASSERT_EQ((KokkosEx::layout_right_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 7, Kokkos::dynamic_extent>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 7, 5>>()).padded_stride.value(0)), 8);

// Construct layout_right_padded mapping from layout_left_padded mapping
ASSERT_EQ(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t>>()).extents(), Kokkos::extents<std::size_t>());
ASSERT_EQ((KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>()).extents()), (Kokkos::extents<std::size_t, 3>()));
ASSERT_EQ((KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>()).padded_stride.value(0)), 0);
ASSERT_EQ((KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, Kokkos::dynamic_extent>>(Kokkos::dextents<size_t, 1>(3))).extents()), (Kokkos::extents<std::size_t, 3>()));
ASSERT_EQ((KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, Kokkos::dynamic_extent>>(Kokkos::dextents<size_t, 1>(3))).padded_stride.value(0)), 0);
ASSERT_EQ((KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_left_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 3>>({}, 4)).extents()), (Kokkos::extents<std::size_t, 3>()));
ASSERT_EQ((KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_left_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 3>>({}, 4)).padded_stride.value(0)), 0);

Expand Down
11 changes: 11 additions & 0 deletions tests/test_submdspan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ using submdspan_test_types =
// layout_right to layout_right Check Extents Preservation
, std::tuple<Kokkos::layout_right, Kokkos::layout_right, Kokkos::extents<size_t,10>, args_t<10>, Kokkos::extents<size_t,10>, Kokkos::full_extent_t>
, std::tuple<Kokkos::layout_right, Kokkos::layout_right, Kokkos::extents<size_t,10>, args_t<10>, Kokkos::extents<size_t,dyn>, std::pair<int,int>>
, std::tuple<Kokkos::layout_right, Kokkos::layout_right, Kokkos::extents<size_t,10>, args_t<10>, Kokkos::extents<size_t,dyn>, std::complex<double>>
, std::tuple<Kokkos::layout_right, Kokkos::layout_right, Kokkos::extents<size_t,10>, args_t<10>, Kokkos::extents<size_t>, int>
, std::tuple<Kokkos::layout_right, Kokkos::layout_right, Kokkos::extents<size_t,10,20>, args_t<10,20>, Kokkos::extents<size_t,10,20>, Kokkos::full_extent_t, Kokkos::full_extent_t>
, std::tuple<Kokkos::layout_right, Kokkos::layout_right, Kokkos::extents<size_t,10,20>, args_t<10,20>, Kokkos::extents<size_t,dyn,20>, std::pair<int,int>, Kokkos::full_extent_t>
Expand Down Expand Up @@ -274,6 +275,10 @@ struct TestSubMDSpan<
return std::pair<int,int>(1,3);
}
MDSPAN_INLINE_FUNCTION
static auto create_slice_arg(std::complex<double>) {
return std::complex<double>{1.,3.};
}
MDSPAN_INLINE_FUNCTION
static auto create_slice_arg(Kokkos::strided_slice<int,int,int>) {
return Kokkos::strided_slice<int,int,int>{1,3,2};
}
Expand All @@ -300,6 +305,12 @@ struct TestSubMDSpan<
}
template<class SrcMDSpan, class SubMDSpan, size_t ... SrcIdx, size_t ... SubIdx, class ... SliceArgs>
MDSPAN_INLINE_FUNCTION
static bool check_submdspan_match(int src_idx, int sub_idx, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence<SrcIdx...>, std::index_sequence<SubIdx...>, std::complex<double> p, SliceArgs ... slices) {
using idx_t = typename SubMDSpan::index_type;
return (sub_mds.extent(sub_idx)==static_cast<idx_t>(p.imag()-p.real())) && check_submdspan_match(++src_idx, ++sub_idx, src_mds, sub_mds, std::index_sequence<SrcIdx...,2>(), std::index_sequence<SubIdx...,1>(), slices...);
}
template<class SrcMDSpan, class SubMDSpan, size_t ... SrcIdx, size_t ... SubIdx, class ... SliceArgs>
MDSPAN_INLINE_FUNCTION
static bool check_submdspan_match(int src_idx, int sub_idx, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence<SrcIdx...>, std::index_sequence<SubIdx...>,
Kokkos::strided_slice<int,int,int> p, SliceArgs ... slices) {
using idx_t = typename SubMDSpan::index_type;
Expand Down
Loading