-
Notifications
You must be signed in to change notification settings - Fork 69
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix device/host call issues in CUDA (when not using relaxed constexpr workaround) #360
Changes from all commits
a0c6886
56d8424
00969d8
aaa2b87
37283cd
97c1813
0fcdad7
9f9a555
43d0595
b0c2452
e818f4d
f86bfe8
f5017f3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,8 @@ | |
|
||
#include <cstddef> | ||
#include <type_traits> | ||
#include <array> | ||
#include <utility> | ||
|
||
namespace MDSPAN_IMPL_STANDARD_NAMESPACE { | ||
namespace detail { | ||
|
@@ -64,6 +66,104 @@ constexpr struct | |
} | ||
} stride; | ||
|
||
// same as std::integral_constant but with __host__ __device__ annotations on | ||
// the implicit conversion function and the call operator | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why do we need/want conversion to and from the std:: counterpart? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Because there are places where we pass in std::integral_constant, and std::integral_constant is the standard-approved way to pass certain values to certain functions. For example |
||
template <class T, T v> | ||
struct integral_constant { | ||
using value_type = T; | ||
using type = integral_constant<T, v>; | ||
|
||
static constexpr T value = v; | ||
|
||
MDSPAN_INLINE_FUNCTION_DEFAULTED | ||
constexpr integral_constant() = default; | ||
|
||
// These interop functions work, because other than the value_type operator | ||
// everything of std::integral_constant works on device (defaulted functions) | ||
MDSPAN_FUNCTION | ||
constexpr integral_constant(std::integral_constant<T,v>) {}; | ||
|
||
MDSPAN_FUNCTION constexpr operator std::integral_constant<T,v>() const noexcept { | ||
return std::integral_constant<T,v>{}; | ||
} | ||
|
||
MDSPAN_FUNCTION constexpr operator value_type() const noexcept { | ||
return value; | ||
} | ||
|
||
MDSPAN_FUNCTION constexpr value_type operator()() const noexcept { | ||
return value; | ||
} | ||
}; | ||
|
||
// The tuple implementation only comes in play when using capabilities | ||
// such as submdspan which require C++17 anyway | ||
#if MDSPAN_HAS_CXX_17 | ||
template<class T, size_t Idx> | ||
struct tuple_member { | ||
using type = T; | ||
static constexpr size_t idx = Idx; | ||
T val; | ||
MDSPAN_FUNCTION constexpr T& get() { return val; } | ||
MDSPAN_FUNCTION constexpr const T& get() const { return val; } | ||
}; | ||
|
||
// A helper class which will be used via a fold expression to | ||
// select the type with the correct Idx in a pack of tuple_member | ||
template<size_t SearchIdx, size_t Idx, class T> | ||
struct tuple_idx_matcher { | ||
using type = tuple_member<T, Idx>; | ||
template<class Other> | ||
MDSPAN_FUNCTION | ||
constexpr auto operator | (Other v) const { | ||
if constexpr (Idx == SearchIdx) { return *this; } | ||
else { return v; } | ||
} | ||
}; | ||
|
||
template<class IdxSeq, class ... Elements> | ||
struct tuple_impl; | ||
|
||
template<size_t ... Idx, class ... Elements> | ||
struct tuple_impl<std::index_sequence<Idx...>, Elements...>: public tuple_member<Elements, Idx> ... { | ||
|
||
MDSPAN_FUNCTION | ||
constexpr tuple_impl(Elements ... vals):tuple_member<Elements, Idx>{vals}... {} | ||
|
||
template<size_t N> | ||
MDSPAN_FUNCTION | ||
constexpr auto& get() { | ||
using base_t = decltype((tuple_idx_matcher<N, Idx, Elements>() | ...) ); | ||
return base_t::type::get(); | ||
} | ||
template<size_t N> | ||
MDSPAN_FUNCTION | ||
constexpr const auto& get() const { | ||
using base_t = decltype((tuple_idx_matcher<N, Idx, Elements>() | ...) ); | ||
return base_t::type::get(); | ||
} | ||
}; | ||
|
||
// A simple tuple-like class for representing slices internally and is compatible with device code | ||
// This doesn't support type access since we don't need it | ||
// This is not meant as an external API | ||
template<class ... Elements> | ||
struct tuple: public tuple_impl<decltype(std::make_index_sequence<sizeof...(Elements)>()), Elements...> { | ||
MDSPAN_FUNCTION | ||
constexpr tuple(Elements ... vals):tuple_impl<decltype(std::make_index_sequence<sizeof...(Elements)>()), Elements ...>(vals ...) {} | ||
}; | ||
|
||
template<size_t Idx, class ... Args> | ||
MDSPAN_FUNCTION | ||
constexpr auto& get(tuple<Args...>& vals) { return vals.template get<Idx>(); } | ||
|
||
template<size_t Idx, class ... Args> | ||
MDSPAN_FUNCTION | ||
constexpr const auto& get(const tuple<Args...>& vals) { return vals.template get<Idx>(); } | ||
|
||
template<class ... Elements> | ||
tuple(Elements ...) -> tuple<Elements...>; | ||
#endif | ||
} // namespace detail | ||
|
||
constexpr struct mdspan_non_standard_tag { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suppose you mean the rank 0 case