Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix device/host call issues in CUDA (when not using relaxed constexpr workaround) #360

Merged
merged 13 commits into from
Oct 4, 2024
Merged
3 changes: 3 additions & 0 deletions include/experimental/__p0009_bits/layout_left.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,12 @@ class layout_left::mapping {

// Not really public, but currently needed to implement fully constexpr useable submdspan:
template<size_t N, class SizeType, size_t ... E, size_t ... Idx>
MDSPAN_INLINE_FUNCTION
constexpr index_type __get_stride(MDSPAN_IMPL_STANDARD_NAMESPACE::extents<SizeType, E...>,std::integer_sequence<size_t, Idx...>) const {
return _MDSPAN_FOLD_TIMES_RIGHT((Idx<N? __extents.template __extent<Idx>():1),1);
}
template<size_t N>
MDSPAN_INLINE_FUNCTION
constexpr index_type __stride() const noexcept {
return __get_stride<N>(__extents, std::make_index_sequence<extents_type::rank()>());
}
Expand All @@ -255,6 +257,7 @@ class layout_left::mapping {
SliceSpecifiers... slices) const;

template<class... SliceSpecifiers>
MDSPAN_INLINE_FUNCTION
friend constexpr auto submdspan_mapping(
const mapping& src, SliceSpecifiers... slices) {
return src.submdspan_mapping_impl(slices...);
Expand Down
3 changes: 3 additions & 0 deletions include/experimental/__p0009_bits/layout_right.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,12 @@ class layout_right::mapping {

// Not really public, but currently needed to implement fully constexpr useable submdspan:
template<size_t N, class SizeType, size_t ... E, size_t ... Idx>
MDSPAN_INLINE_FUNCTION
constexpr index_type __get_stride(MDSPAN_IMPL_STANDARD_NAMESPACE::extents<SizeType, E...>,std::integer_sequence<size_t, Idx...>) const {
return _MDSPAN_FOLD_TIMES_RIGHT((Idx>N? __extents.template __extent<Idx>():1),1);
}
template<size_t N>
MDSPAN_INLINE_FUNCTION
constexpr index_type __stride() const noexcept {
return __get_stride<N>(__extents, std::make_index_sequence<extents_type::rank()>());
}
Expand All @@ -252,6 +254,7 @@ class layout_right::mapping {
SliceSpecifiers... slices) const;

template<class... SliceSpecifiers>
MDSPAN_INLINE_FUNCTION
friend constexpr auto submdspan_mapping(
const mapping& src, SliceSpecifiers... slices) {
return src.submdspan_mapping_impl(slices...);
Expand Down
36 changes: 18 additions & 18 deletions include/experimental/__p0009_bits/layout_stride.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,28 +197,22 @@ struct layout_stride {
}

template<class IntegralType>
MDSPAN_INLINE_FUNCTION
static constexpr const __strides_storage_t fill_strides(const std::array<IntegralType,extents_type::rank()>& s) {
return __strides_storage_t{static_cast<index_type>(s[Idxs])...};
}

MDSPAN_TEMPLATE_REQUIRES(
class IntegralType,
// The is_convertible condition is added to make sfinae valid
// the extents_type::rank() > 0 is added to avoid use of non-standard zero length c-array
(std::is_convertible<IntegralType, typename extents_type::index_type>::value && (extents_type::rank() > 0))
(std::is_convertible<IntegralType, typename extents_type::index_type>::value)
)
MDSPAN_INLINE_FUNCTION
// despite the requirement some compilers still complain about zero length array during parsing
// making it length 1 now, but since the thing can't be instantiated due to requirement the actual
// instantiation of strides_storage will not fail despite mismatching length
// Need to avoid zero length c-array
static constexpr const __strides_storage_t fill_strides(mdspan_non_standard_tag, const IntegralType (&s)[extents_type::rank()>0?extents_type::rank():1]) {
return __strides_storage_t{static_cast<index_type>(s[Idxs])...};
}

#ifdef __cpp_lib_span
template<class IntegralType>
MDSPAN_INLINE_FUNCTION
static constexpr const __strides_storage_t fill_strides(const std::span<IntegralType,extents_type::rank()>& s) {
return __strides_storage_t{static_cast<index_type>(s[Idxs])...};
}
Expand All @@ -242,10 +236,13 @@ struct layout_stride {
// Can't use defaulted parameter in the __deduction_workaround template because of a bug in MSVC warning C4348.
using __impl = __deduction_workaround<std::make_index_sequence<Extents::rank()>>;

MDSPAN_FUNCTION
static constexpr __strides_storage_t strides_storage(detail::with_rank<0>) {
return {};
}

template <std::size_t N>
MDSPAN_FUNCTION
static constexpr __strides_storage_t strides_storage(detail::with_rank<N>) {
__strides_storage_t s{};

Expand Down Expand Up @@ -273,7 +270,7 @@ struct layout_stride {

//--------------------------------------------------------------------------------

MDSPAN_INLINE_FUNCTION_DEFAULTED constexpr mapping() noexcept
MDSPAN_INLINE_FUNCTION constexpr mapping() noexcept
#if defined(_MDSPAN_USE_ATTRIBUTE_NO_UNIQUE_ADDRESS)
: __members{
#else
Expand All @@ -299,7 +296,6 @@ struct layout_stride {
_MDSPAN_TRAIT(std::is_nothrow_constructible, typename Extents::index_type, const std::remove_const_t<IntegralTypes>&)
)
)
MDSPAN_INLINE_FUNCTION
constexpr
mapping(
extents_type const& e,
Expand Down Expand Up @@ -333,19 +329,16 @@ struct layout_stride {
// MSVC 19.32 does not like using index_type here, requires the typename Extents::index_type
// error C2641: cannot deduce template arguments for 'MDSPAN_IMPL_STANDARD_NAMESPACE::layout_stride::mapping'
_MDSPAN_TRAIT(std::is_convertible, const std::remove_const_t<IntegralTypes>&, typename Extents::index_type) &&
_MDSPAN_TRAIT(std::is_nothrow_constructible, typename Extents::index_type, const std::remove_const_t<IntegralTypes>&) &&
(Extents::rank() > 0)
_MDSPAN_TRAIT(std::is_nothrow_constructible, typename Extents::index_type, const std::remove_const_t<IntegralTypes>&)
)
)
MDSPAN_INLINE_FUNCTION
constexpr
mapping(
mdspan_non_standard_tag,
extents_type const& e,
// despite the requirement some compilers still complain about zero length array during parsing
// making it length 1 now, but since the thing can't be instantiated due to requirement the actual
// instantiation of strides_storage will not fail despite mismatching length
IntegralTypes (&s)[extents_type::rank()>0?extents_type::rank():1]
// Need to avoid zero-length c-array
const IntegralTypes (&s)[extents_type::rank()>0?extents_type::rank():1]
) noexcept
#if defined(_MDSPAN_USE_ATTRIBUTE_NO_UNIQUE_ADDRESS)
: __members{
Expand Down Expand Up @@ -379,7 +372,6 @@ struct layout_stride {
_MDSPAN_TRAIT(std::is_nothrow_constructible, typename Extents::index_type, const std::remove_const_t<IntegralTypes>&)
)
)
MDSPAN_INLINE_FUNCTION
constexpr
mapping(
extents_type const& e,
Expand Down Expand Up @@ -476,7 +468,8 @@ struct layout_stride {
MDSPAN_INLINE_FUNCTION
constexpr index_type required_span_size() const noexcept {
index_type span_size = 1;
for(unsigned r = 0; r < extents_type::rank(); r++) {
// using int here to avoid warning about pointless comparison to 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose you mean the rank 0 case

for(int r = 0; r < static_cast<int>(extents_type::rank()); r++) {
// Return early if any of the extents are zero
if(extents().extent(r)==0) return 0;
span_size += ( static_cast<index_type>(extents().extent(r) - 1 ) * __strides_storage()[r]);
Expand Down Expand Up @@ -509,15 +502,18 @@ struct layout_stride {
MDSPAN_INLINE_FUNCTION static constexpr bool is_unique() noexcept { return true; }

private:
MDSPAN_INLINE_FUNCTION
constexpr bool exhaustive_for_nonzero_span_size() const
{
return required_span_size() == __get_size(extents(), std::make_index_sequence<extents_type::rank()>());
}

MDSPAN_INLINE_FUNCTION
constexpr bool is_exhaustive_impl(detail::with_rank<0>) const
{
return true;
}
MDSPAN_INLINE_FUNCTION
constexpr bool is_exhaustive_impl(detail::with_rank<1>) const
{
if (required_span_size() != static_cast<index_type>(0)) {
Expand All @@ -526,6 +522,7 @@ struct layout_stride {
return stride(0) == 1;
}
template <std::size_t N>
MDSPAN_INLINE_FUNCTION
constexpr bool is_exhaustive_impl(detail::with_rank<N>) const
{
if (required_span_size() != static_cast<index_type>(0)) {
Expand Down Expand Up @@ -627,6 +624,7 @@ struct layout_stride {
SliceSpecifiers... slices) const;

template<class... SliceSpecifiers>
MDSPAN_INLINE_FUNCTION
friend constexpr auto submdspan_mapping(
const mapping& src, SliceSpecifiers... slices) {
return src.submdspan_mapping_impl(slices...);
Expand All @@ -637,10 +635,12 @@ struct layout_stride {
namespace detail {

template <class Layout, class Extents, class Mapping>
MDSPAN_INLINE_FUNCTION
constexpr void validate_strides(with_rank<0>, Layout, const Extents&, const Mapping&)
{}

template <std::size_t N, class Layout, class Extents, class Mapping>
MDSPAN_INLINE_FUNCTION
constexpr void validate_strides(with_rank<N>, Layout, const Extents& ext, const Mapping& other)
{
static_assert(std::is_same<typename Mapping::layout_type, layout_stride>::value &&
Expand Down
100 changes: 100 additions & 0 deletions include/experimental/__p0009_bits/utility.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

#include <cstddef>
#include <type_traits>
#include <array>
#include <utility>

namespace MDSPAN_IMPL_STANDARD_NAMESPACE {
namespace detail {
Expand Down Expand Up @@ -64,6 +66,104 @@ constexpr struct
}
} stride;

// same as std::integral_constant but with __host__ __device__ annotations on
// the implicit conversion function and the call operator
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need/want conversion from to the std:: counterpart?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because there are places where we pass in std::integral_constant, and std::integral_constant is the standard approved way to pass certain ways to certain functions. For example submdspan(a, std::integral_constant<int, 1>()) is fine in device code.

template <class T, T v>
struct integral_constant {
using value_type = T;
using type = integral_constant<T, v>;

static constexpr T value = v;

MDSPAN_INLINE_FUNCTION_DEFAULTED
constexpr integral_constant() = default;

// These interop functions work, because other than the value_type operator
// everything of std::integral_constant works on device (defaulted functions)
MDSPAN_FUNCTION
constexpr integral_constant(std::integral_constant<T,v>) {};

MDSPAN_FUNCTION constexpr operator std::integral_constant<T,v>() const noexcept {
return std::integral_constant<T,v>{};
}

MDSPAN_FUNCTION constexpr operator value_type() const noexcept {
return value;
}

MDSPAN_FUNCTION constexpr value_type operator()() const noexcept {
return value;
}
};

// The tuple implementation only comes in play when using capabilities
// such as submdspan which require C++17 anyway
#if MDSPAN_HAS_CXX_17
template<class T, size_t Idx>
struct tuple_member {
using type = T;
static constexpr size_t idx = Idx;
T val;
MDSPAN_FUNCTION constexpr T& get() { return val; }
MDSPAN_FUNCTION constexpr const T& get() const { return val; }
};

// A helper class which will be used via a fold expression to
// select the type with the correct Idx in a pack of tuple_member
template<size_t SearchIdx, size_t Idx, class T>
struct tuple_idx_matcher {
using type = tuple_member<T, Idx>;
template<class Other>
MDSPAN_FUNCTION
constexpr auto operator | (Other v) const {
if constexpr (Idx == SearchIdx) { return *this; }
else { return v; }
}
};

template<class IdxSeq, class ... Elements>
struct tuple_impl;

template<size_t ... Idx, class ... Elements>
struct tuple_impl<std::index_sequence<Idx...>, Elements...>: public tuple_member<Elements, Idx> ... {

MDSPAN_FUNCTION
constexpr tuple_impl(Elements ... vals):tuple_member<Elements, Idx>{vals}... {}

template<size_t N>
MDSPAN_FUNCTION
constexpr auto& get() {
using base_t = decltype((tuple_idx_matcher<N, Idx, Elements>() | ...) );
return base_t::type::get();
}
template<size_t N>
MDSPAN_FUNCTION
constexpr const auto& get() const {
using base_t = decltype((tuple_idx_matcher<N, Idx, Elements>() | ...) );
return base_t::type::get();
}
};

// A simple tuple-like class for representing slices internally and is compatible with device code
// This doesn't support type access since we don't need it
// This is not meant as an external API
template<class ... Elements>
struct tuple: public tuple_impl<decltype(std::make_index_sequence<sizeof...(Elements)>()), Elements...> {
MDSPAN_FUNCTION
constexpr tuple(Elements ... vals):tuple_impl<decltype(std::make_index_sequence<sizeof...(Elements)>()), Elements ...>(vals ...) {}
};

template<size_t Idx, class ... Args>
MDSPAN_FUNCTION
constexpr auto& get(tuple<Args...>& vals) { return vals.template get<Idx>(); }

template<size_t Idx, class ... Args>
MDSPAN_FUNCTION
constexpr const auto& get(const tuple<Args...>& vals) { return vals.template get<Idx>(); }

template<class ... Elements>
tuple(Elements ...) -> tuple<Elements...>;
#endif
} // namespace detail

constexpr struct mdspan_non_standard_tag {
Expand Down
Loading
Loading