Skip to content

Commit

Permalink
include kokkos/mdspan#360 changes
Browse files Browse the repository at this point in the history
  • Loading branch information
nmm0 committed Oct 2, 2024
1 parent 853d1c2 commit 52770ad
Show file tree
Hide file tree
Showing 9 changed files with 265 additions and 34 deletions.
137 changes: 137 additions & 0 deletions tpls/mdspan/include/experimental/__p0009_bits/device_support.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
#pragma once
#include <type_traits>
#include "macros.hpp"
#if !defined(_MDSPAN_HAS_CUDA) && !defined(_MDSPAN_HAS_HIP)
#include <tuple>
#endif

namespace MDSPAN_IMPL_STANDARD_NAMESPACE {
namespace detail {

// same as std::integral_constant but with __host__ __device__ annotations on
// the implicit conversion function and the call operator
template <class T, T v>
struct integral_constant {
using value_type = T;
using type = integral_constant<T, v>;

MDSPAN_INLINE_FUNCTION_DEFAULTED
constexpr integral_constant() = default;

MDSPAN_INLINE_FUNCTION_DEFAULTED
constexpr integral_constant(std::integral_constant<T,v>) {};

static constexpr T value = v;
MDSPAN_INLINE_FUNCTION constexpr operator value_type() const noexcept {
return value;
}
MDSPAN_INLINE_FUNCTION constexpr value_type operator()() const noexcept {
return value;
}
MDSPAN_INLINE_FUNCTION constexpr operator std::integral_constant<T,v>() const noexcept {
return std::integral_constant<T,v>{};
}
};

template<class T, size_t Idx>
struct tuple_member {
using type = T;
static constexpr size_t idx = Idx;
T val;
MDSPAN_FUNCTION constexpr T& get() { return val; }
MDSPAN_FUNCTION constexpr const T& get() const { return val; }
};

template<size_t SearchIdx, size_t Idx, class T>
struct tuple_idx_matcher {
using type = tuple_member<T, Idx>;
template<class Other>
MDSPAN_FUNCTION
constexpr auto operator + (Other v) const {
if constexpr (Idx == SearchIdx) { return *this; }
else { return v; }
}
};

template<class SearchT, size_t Idx, class T>
struct tuple_type_matcher {
using type = tuple_member<T, Idx>;
template<class Other>
MDSPAN_FUNCTION
constexpr auto operator + (Other v) const {
if constexpr (std::is_same_v<T, SearchT>) { return *this; }
else { return v; }
}
};

template<class IdxSeq, class ... Elements>
struct tuple_impl;

template<size_t ... Idx, class ... Elements>
struct tuple_impl<std::index_sequence<Idx...>, Elements...>: public tuple_member<Elements, Idx> ... {

MDSPAN_FUNCTION
constexpr tuple_impl(Elements ... vals):tuple_member<Elements, Idx>{vals}... {}

template<class T>
MDSPAN_FUNCTION
constexpr T& get() {
using base_t = decltype((tuple_type_matcher<T, Idx, Elements>() + ...) );
return base_t::type::get();
}
template<class T>
MDSPAN_FUNCTION
constexpr const T& get() const {
using base_t = decltype((tuple_type_matcher<T, Idx, Elements>() + ...) );
return base_t::type::get();
}

template<size_t N>
MDSPAN_FUNCTION
constexpr auto& get() {
using base_t = decltype((tuple_idx_matcher<N, Idx, Elements>() + ...) );
return base_t::type::get();
}
template<size_t N>
MDSPAN_FUNCTION
constexpr const auto& get() const {
using base_t = decltype((tuple_idx_matcher<N, Idx, Elements>() + ...) );
return base_t::type::get();
}
};

template<class ... Elements>
struct tuple: public tuple_impl<decltype(std::make_index_sequence<sizeof...(Elements)>()), Elements...> {
MDSPAN_FUNCTION
constexpr tuple(Elements ... vals):tuple_impl<decltype(std::make_index_sequence<sizeof...(Elements)>()), Elements ...>(vals ...) {}
};

template<class T, class ... Args>
MDSPAN_FUNCTION
constexpr auto& get(tuple<Args...>& vals) { return vals.template get<T>(); }

template<class T, class ... Args>
MDSPAN_FUNCTION
constexpr const T& get(const tuple<Args...>& vals) { return vals.template get<T>(); }

template<size_t Idx, class ... Args>
MDSPAN_FUNCTION
constexpr auto& get(tuple<Args...>& vals) { return vals.template get<Idx>(); }

template<size_t Idx, class ... Args>
MDSPAN_FUNCTION
constexpr const auto& get(const tuple<Args...>& vals) { return vals.template get<Idx>(); }

template<class ... Elements>
tuple(Elements ...) -> tuple<Elements...>;

template<class T, size_t ... Idx>
constexpr auto c_array_to_std(std::index_sequence<Idx...>, const T(&values)[sizeof...(Idx)]) {
return std::array{values[Idx]...};
}
template<class T, size_t N>
constexpr auto c_array_to_std(const T(&values)[N]) {
return c_array_to_std(std::make_index_sequence<N>(), values);
}
}
}
3 changes: 3 additions & 0 deletions tpls/mdspan/include/experimental/__p0009_bits/layout_left.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,12 @@ class layout_left::mapping {

// Not really public, but currently needed to implement fully constexpr useable submdspan:
template<size_t N, class SizeType, size_t ... E, size_t ... Idx>
MDSPAN_INLINE_FUNCTION
constexpr index_type __get_stride(MDSPAN_IMPL_STANDARD_NAMESPACE::extents<SizeType, E...>,std::integer_sequence<size_t, Idx...>) const {
return _MDSPAN_FOLD_TIMES_RIGHT((Idx<N? __extents.template __extent<Idx>():1),1);
}
template<size_t N>
MDSPAN_INLINE_FUNCTION
constexpr index_type __stride() const noexcept {
return __get_stride<N>(__extents, std::make_index_sequence<extents_type::rank()>());
}
Expand All @@ -255,6 +257,7 @@ class layout_left::mapping {
SliceSpecifiers... slices) const;

template<class... SliceSpecifiers>
MDSPAN_INLINE_FUNCTION
friend constexpr auto submdspan_mapping(
const mapping& src, SliceSpecifiers... slices) {
return src.submdspan_mapping_impl(slices...);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,12 @@ class layout_right::mapping {

// Not really public, but currently needed to implement fully constexpr useable submdspan:
template<size_t N, class SizeType, size_t ... E, size_t ... Idx>
MDSPAN_INLINE_FUNCTION
constexpr index_type __get_stride(MDSPAN_IMPL_STANDARD_NAMESPACE::extents<SizeType, E...>,std::integer_sequence<size_t, Idx...>) const {
return _MDSPAN_FOLD_TIMES_RIGHT((Idx>N? __extents.template __extent<Idx>():1),1);
}
template<size_t N>
MDSPAN_INLINE_FUNCTION
constexpr index_type __stride() const noexcept {
return __get_stride<N>(__extents, std::make_index_sequence<extents_type::rank()>());
}
Expand All @@ -252,6 +254,7 @@ class layout_right::mapping {
SliceSpecifiers... slices) const;

template<class... SliceSpecifiers>
MDSPAN_INLINE_FUNCTION
friend constexpr auto submdspan_mapping(
const mapping& src, SliceSpecifiers... slices) {
return src.submdspan_mapping_impl(slices...);
Expand Down
29 changes: 23 additions & 6 deletions tpls/mdspan/include/experimental/__p0009_bits/layout_stride.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,15 @@ struct layout_stride {
template<class IntegralType>
MDSPAN_INLINE_FUNCTION
static constexpr const __strides_storage_t fill_strides(const std::array<IntegralType,extents_type::rank()>& s) {
return __strides_storage_t{static_cast<index_type>(s[Idxs])...};
// avoid warning for use of host std::array operator[]
#if defined(_MDSPAN_HAS_CUDA) || defined(_MDSPAN_HAS_HIP)
const IntegralType* s_ptr = reinterpret_cast<const IntegralType*>(&s);
#else
const IntegralType *s_ptr = s.data();
#endif
// for rank == 0 the expansion is empty and s_ptr becomes unused
detail::maybe_unused_variable(s_ptr);
return __strides_storage_t{static_cast<index_type>(s_ptr[Idxs])...};
}

MDSPAN_TEMPLATE_REQUIRES(
Expand All @@ -218,7 +226,6 @@ struct layout_stride {

#ifdef __cpp_lib_span
template<class IntegralType>
MDSPAN_INLINE_FUNCTION
static constexpr const __strides_storage_t fill_strides(const std::span<IntegralType,extents_type::rank()>& s) {
return __strides_storage_t{static_cast<index_type>(s[Idxs])...};
}
Expand All @@ -242,10 +249,13 @@ struct layout_stride {
// Can't use defaulted parameter in the __deduction_workaround template because of a bug in MSVC warning C4348.
using __impl = __deduction_workaround<std::make_index_sequence<Extents::rank()>>;

MDSPAN_FUNCTION
static constexpr __strides_storage_t strides_storage(detail::with_rank<0>) {
return {};
}

template <std::size_t N>
MDSPAN_FUNCTION
static constexpr __strides_storage_t strides_storage(detail::with_rank<N>) {
__strides_storage_t s{};

Expand Down Expand Up @@ -273,7 +283,7 @@ struct layout_stride {

//--------------------------------------------------------------------------------

MDSPAN_INLINE_FUNCTION_DEFAULTED constexpr mapping() noexcept
MDSPAN_INLINE_FUNCTION constexpr mapping() noexcept
#if defined(_MDSPAN_USE_ATTRIBUTE_NO_UNIQUE_ADDRESS)
: __members{
#else
Expand All @@ -299,7 +309,7 @@ struct layout_stride {
_MDSPAN_TRAIT(std::is_nothrow_constructible, typename Extents::index_type, const std::remove_const_t<IntegralTypes>&)
)
)
MDSPAN_INLINE_FUNCTION
MDSPAN_FUNCTION
constexpr
mapping(
extents_type const& e,
Expand Down Expand Up @@ -379,7 +389,6 @@ struct layout_stride {
_MDSPAN_TRAIT(std::is_nothrow_constructible, typename Extents::index_type, const std::remove_const_t<IntegralTypes>&)
)
)
MDSPAN_INLINE_FUNCTION
constexpr
mapping(
extents_type const& e,
Expand Down Expand Up @@ -476,7 +485,8 @@ struct layout_stride {
MDSPAN_INLINE_FUNCTION
constexpr index_type required_span_size() const noexcept {
index_type span_size = 1;
for(unsigned r = 0; r < extents_type::rank(); r++) {
// using int here to avoid warning about pointless comparison to 0
for(int r = 0; r < static_cast<int>(extents_type::rank()); r++) {
// Return early if any of the extents are zero
if(extents().extent(r)==0) return 0;
span_size += ( static_cast<index_type>(extents().extent(r) - 1 ) * __strides_storage()[r]);
Expand Down Expand Up @@ -509,15 +519,18 @@ struct layout_stride {
MDSPAN_INLINE_FUNCTION static constexpr bool is_unique() noexcept { return true; }

private:
MDSPAN_INLINE_FUNCTION
constexpr bool exhaustive_for_nonzero_span_size() const
{
return required_span_size() == __get_size(extents(), std::make_index_sequence<extents_type::rank()>());
}

MDSPAN_INLINE_FUNCTION
constexpr bool is_exhaustive_impl(detail::with_rank<0>) const
{
return true;
}
MDSPAN_INLINE_FUNCTION
constexpr bool is_exhaustive_impl(detail::with_rank<1>) const
{
if (required_span_size() != static_cast<index_type>(0)) {
Expand All @@ -526,6 +539,7 @@ struct layout_stride {
return stride(0) == 1;
}
template <std::size_t N>
MDSPAN_INLINE_FUNCTION
constexpr bool is_exhaustive_impl(detail::with_rank<N>) const
{
if (required_span_size() != static_cast<index_type>(0)) {
Expand Down Expand Up @@ -627,6 +641,7 @@ struct layout_stride {
SliceSpecifiers... slices) const;

template<class... SliceSpecifiers>
MDSPAN_INLINE_FUNCTION
friend constexpr auto submdspan_mapping(
const mapping& src, SliceSpecifiers... slices) {
return src.submdspan_mapping_impl(slices...);
Expand All @@ -637,10 +652,12 @@ struct layout_stride {
namespace detail {

template <class Layout, class Extents, class Mapping>
MDSPAN_INLINE_FUNCTION
constexpr void validate_strides(with_rank<0>, Layout, const Extents&, const Mapping&)
{}

template <std::size_t N, class Layout, class Extents, class Mapping>
MDSPAN_INLINE_FUNCTION
constexpr void validate_strides(with_rank<N>, Layout, const Extents& ext, const Mapping& other)
{
static_assert(std::is_same<typename Mapping::layout_type, layout_stride>::value &&
Expand Down
4 changes: 4 additions & 0 deletions tpls/mdspan/include/experimental/__p0009_bits/utility.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ constexpr struct
}
} stride;

template<class T>
MDSPAN_INLINE_FUNCTION
constexpr void maybe_unused_variable(const T&) {}

} // namespace detail

constexpr struct mdspan_non_standard_tag {
Expand Down
Loading

0 comments on commit 52770ad

Please sign in to comment.