Skip to content

Commit

Permalink
Improve rebind utilities for cuco hash tables (#598)
Browse files Browse the repository at this point in the history
This PR renames all `with_*` member functions to `rebind_*` for improved
clarity.

The legacy `with_operators` will be removed once libcudf is migrated to
use the new `rebind_operators`.

---------

Co-authored-by: Daniel Jünger <[email protected]>
  • Loading branch information
PointKernel and sleeepyjack authored Sep 17, 2024
1 parent 9ef3535 commit 5602381
Show file tree
Hide file tree
Showing 13 changed files with 280 additions and 157 deletions.
16 changes: 4 additions & 12 deletions include/cuco/detail/probing_scheme/probing_scheme_impl.inl
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ __host__ __device__ constexpr linear_probing<CGSize, Hash>::linear_probing(Hash

template <int32_t CGSize, typename Hash>
template <typename NewHash>
__host__ __device__ constexpr auto linear_probing<CGSize, Hash>::with_hash_function(
__host__ __device__ constexpr auto linear_probing<CGSize, Hash>::rebind_hash_function(
NewHash const& hash) const noexcept
{
return linear_probing<cg_size, NewHash>{hash};
Expand Down Expand Up @@ -143,28 +143,20 @@ __host__ __device__ constexpr double_hashing<CGSize, Hash1, Hash2>::double_hashi

template <int32_t CGSize, typename Hash1, typename Hash2>
__host__ __device__ constexpr double_hashing<CGSize, Hash1, Hash2>::double_hashing(
cuco::pair<Hash1, Hash2> const& hash)
cuda::std::tuple<Hash1, Hash2> const& hash)
: hash1_{hash.first}, hash2_{hash.second}
{
}

template <int32_t CGSize, typename Hash1, typename Hash2>
template <typename NewHash1, typename NewHash2>
__host__ __device__ constexpr auto double_hashing<CGSize, Hash1, Hash2>::with_hash_function(
NewHash1 const& hash1, NewHash2 const& hash2) const noexcept
{
return double_hashing<cg_size, NewHash1, NewHash2>{hash1, hash2};
}

template <int32_t CGSize, typename Hash1, typename Hash2>
template <typename NewHash, typename Enable>
__host__ __device__ constexpr auto double_hashing<CGSize, Hash1, Hash2>::with_hash_function(
__host__ __device__ constexpr auto double_hashing<CGSize, Hash1, Hash2>::rebind_hash_function(
NewHash const& hash) const
{
static_assert(cuco::is_tuple_like<NewHash>::value,
"The given hasher must be a tuple-like object");

auto const [hash1, hash2] = cuco::pair{hash};
auto const [hash1, hash2] = cuda::std::tuple{hash};
using hash1_type = cuda::std::decay_t<decltype(hash1)>;
using hash2_type = cuda::std::decay_t<decltype(hash2)>;
return double_hashing<cg_size, hash1_type, hash2_type>{hash1, hash2};
Expand Down
4 changes: 2 additions & 2 deletions include/cuco/detail/static_map/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ CUCO_KERNEL __launch_bounds__(BlockSize) void insert_or_apply_shmem(
ref.probing_scheme(),
{},
storage};
auto shared_map_ref = std::move(shared_map).with(cuco::op::insert_or_apply);
auto shared_map_ref = shared_map.rebind_operators(cuco::op::insert_or_apply);
shared_map_ref.initialize(block);
block.sync();

Expand Down Expand Up @@ -262,4 +262,4 @@ CUCO_KERNEL __launch_bounds__(BlockSize) void insert_or_apply_shmem(
}
}
}
} // namespace cuco::static_map_ns::detail
} // namespace cuco::static_map_ns::detail
77 changes: 63 additions & 14 deletions include/cuco/detail/static_map/static_map_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -296,11 +296,17 @@ template <typename Key,
typename StorageRef,
typename... Operators>
template <typename... NewOperators>
auto static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::with(
NewOperators...) && noexcept
__host__ __device__ constexpr auto
static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::with_operators(
NewOperators...) const noexcept
{
return static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, NewOperators...>{
std::move(*this)};
cuco::empty_key<Key>{this->empty_key_sentinel()},
cuco::empty_value<T>{this->empty_value_sentinel()},
this->key_eq(),
this->probing_scheme(),
{},
this->storage_ref()};
}

template <typename Key,
Expand All @@ -311,22 +317,65 @@ template <typename Key,
typename StorageRef,
typename... Operators>
template <typename... NewOperators>
__host__ __device__ auto constexpr static_map_ref<Key,
T,
Scope,
KeyEqual,
ProbingScheme,
StorageRef,
Operators...>::with_operators(NewOperators...)
const noexcept
__host__ __device__ constexpr auto
static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::rebind_operators(
NewOperators...) const noexcept
{
return static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, NewOperators...>{
cuco::empty_key<Key>{this->empty_key_sentinel()},
cuco::empty_value<T>{this->empty_value_sentinel()},
this->key_eq(),
this->impl_.probing_scheme(),
this->probing_scheme(),
{},
this->impl_.storage_ref()};
this->storage_ref()};
}

template <typename Key,
typename T,
cuda::thread_scope Scope,
typename KeyEqual,
typename ProbingScheme,
typename StorageRef,
typename... Operators>
template <typename NewKeyEqual>
__host__ __device__ constexpr auto
static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::rebind_key_eq(
NewKeyEqual const& key_equal) const noexcept
{
return static_map_ref<Key, T, Scope, NewKeyEqual, ProbingScheme, StorageRef, Operators...>{
cuco::empty_key<Key>{this->empty_key_sentinel()},
cuco::empty_value<T>{this->empty_value_sentinel()},
key_equal,
this->probing_scheme(),
{},
this->storage_ref()};
}

template <typename Key,
typename T,
cuda::thread_scope Scope,
typename KeyEqual,
typename ProbingScheme,
typename StorageRef,
typename... Operators>
template <typename NewHash>
__host__ __device__ constexpr auto
static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::
rebind_hash_function(NewHash const& hash) const
{
auto const probing_scheme = this->probing_scheme().rebind_hash_function(hash);
return static_map_ref<Key,
T,
Scope,
KeyEqual,
cuda::std::decay_t<decltype(probing_scheme)>,
StorageRef,
Operators...>{cuco::empty_key<Key>{this->empty_key_sentinel()},
cuco::empty_value<T>{this->empty_value_sentinel()},
this->key_eq(),
probing_scheme,
{},
this->storage_ref()};
}

template <typename Key,
Expand All @@ -349,7 +398,7 @@ static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>
cuco::empty_value<T>{this->empty_value_sentinel()},
cuco::erased_key<Key>{this->erased_key_sentinel()},
this->key_eq(),
this->impl_.probing_scheme(),
this->probing_scheme(),
scope,
storage_ref_type{this->window_extent(), memory_to_use}};
}
Expand Down
69 changes: 64 additions & 5 deletions include/cuco/detail/static_multimap/static_multimap_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -295,11 +295,22 @@ template <typename Key,
typename StorageRef,
typename... Operators>
template <typename... NewOperators>
auto static_multimap_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::with(
NewOperators...) && noexcept
__host__ __device__ auto constexpr static_multimap_ref<
Key,
T,
Scope,
KeyEqual,
ProbingScheme,
StorageRef,
Operators...>::with_operators(NewOperators...) const noexcept
{
return static_multimap_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, NewOperators...>{
std::move(*this)};
cuco::empty_key<Key>{this->empty_key_sentinel()},
cuco::empty_value<T>{this->empty_value_sentinel()},
this->key_eq(),
this->probing_scheme(),
{},
impl_.storage_ref()};
}

template <typename Key,
Expand All @@ -317,15 +328,63 @@ __host__ __device__ auto constexpr static_multimap_ref<
KeyEqual,
ProbingScheme,
StorageRef,
Operators...>::with_operators(NewOperators...) const noexcept
Operators...>::rebind_operators(NewOperators...) const noexcept
{
return static_multimap_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, NewOperators...>{
cuco::empty_key<Key>{this->empty_key_sentinel()},
cuco::empty_value<T>{this->empty_value_sentinel()},
this->key_eq(),
impl_.probing_scheme(),
{},
impl_.storage_ref()};
this->storage_ref()};
}

template <typename Key,
typename T,
cuda::thread_scope Scope,
typename KeyEqual,
typename ProbingScheme,
typename StorageRef,
typename... Operators>
template <typename NewKeyEqual>
__host__ __device__ constexpr auto
static_multimap_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::
rebind_key_eq(NewKeyEqual const& key_equal) const noexcept
{
return static_multimap_ref<Key, T, Scope, NewKeyEqual, ProbingScheme, StorageRef, Operators...>{
cuco::empty_key<Key>{this->empty_key_sentinel()},
cuco::empty_value<T>{this->empty_value_sentinel()},
key_equal,
this->probing_scheme(),
{},
this->storage_ref()};
}

template <typename Key,
typename T,
cuda::thread_scope Scope,
typename KeyEqual,
typename ProbingScheme,
typename StorageRef,
typename... Operators>
template <typename NewHash>
__host__ __device__ constexpr auto
static_multimap_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::
rebind_hash_function(NewHash const& hash) const
{
auto const probing_scheme = this->probing_scheme().rebind_hash_function(hash);
return static_multimap_ref<Key,
T,
Scope,
KeyEqual,
cuda::std::decay_t<decltype(probing_scheme)>,
StorageRef,
Operators...>{cuco::empty_key<Key>{this->empty_key_sentinel()},
cuco::empty_value<T>{this->empty_value_sentinel()},
this->key_eq(),
probing_scheme,
{},
this->storage_ref()};
}

template <typename Key,
Expand Down
11 changes: 6 additions & 5 deletions include/cuco/detail/static_multiset/static_multiset.inl
Original file line number Diff line number Diff line change
Expand Up @@ -308,10 +308,11 @@ static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>
ProbeHash const& probe_hash,
cuda::stream_ref stream) const
{
return impl_->count(first,
last,
ref(op::count).with_key_eq(probe_key_equal).with_hash_function(probe_hash),
stream);
return impl_->count(
first,
last,
ref(op::count).rebind_key_eq(probe_key_equal).rebind_hash_function(probe_hash),
stream);
}

template <class Key,
Expand All @@ -333,7 +334,7 @@ static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>
return impl_->count_outer(
first,
last,
ref(op::count).with_key_eq(probe_key_equal).with_hash_function(probe_hash),
ref(op::count).rebind_key_eq(probe_key_equal).rebind_hash_function(probe_hash),
stream);
}

Expand Down
33 changes: 19 additions & 14 deletions include/cuco/detail/static_multiset/static_multiset_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -251,11 +251,16 @@ template <typename Key,
typename StorageRef,
typename... Operators>
template <typename... NewOperators>
auto static_multiset_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::with(
NewOperators...) && noexcept
__host__ __device__ constexpr auto
static_multiset_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::with_operators(
NewOperators...) const noexcept
{
return static_multiset_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, NewOperators...>{
std::move(*this)};
cuco::empty_key<Key>{this->empty_key_sentinel()},
this->key_eq(),
this->probing_scheme(),
{},
this->storage_ref()};
}

template <typename Key,
Expand All @@ -266,15 +271,15 @@ template <typename Key,
typename... Operators>
template <typename... NewOperators>
__host__ __device__ constexpr auto
static_multiset_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::with_operators(
NewOperators...) const noexcept
static_multiset_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::
rebind_operators(NewOperators...) const noexcept
{
return static_multiset_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, NewOperators...>{
cuco::empty_key<Key>{this->empty_key_sentinel()},
this->key_eq(),
this->impl_.probing_scheme(),
this->probing_scheme(),
{},
this->impl_.storage_ref()};
this->storage_ref()};
}

template <typename Key,
Expand All @@ -285,15 +290,15 @@ template <typename Key,
typename... Operators>
template <typename NewKeyEqual>
__host__ __device__ constexpr auto
static_multiset_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::with_key_eq(
static_multiset_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::rebind_key_eq(
NewKeyEqual const& key_equal) const noexcept
{
return static_multiset_ref<Key, Scope, NewKeyEqual, ProbingScheme, StorageRef, Operators...>{
cuco::empty_key<Key>{this->empty_key_sentinel()},
key_equal,
this->impl_.probing_scheme(),
this->probing_scheme(),
{},
this->impl_.storage_ref()};
this->storage_ref()};
}

template <typename Key,
Expand All @@ -305,19 +310,19 @@ template <typename Key,
template <typename NewHash>
__host__ __device__ constexpr auto
static_multiset_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::
with_hash_function(NewHash const& hash) const
rebind_hash_function(NewHash const& hash) const
{
auto const probing_scheme = this->impl_.probing_scheme().with_hash_function(hash);
auto const probing_scheme = this->probing_scheme().rebind_hash_function(hash);
return static_multiset_ref<Key,
Scope,
KeyEqual,
cuda::std::decay_t<decltype(probing_scheme)>,
StorageRef,
Operators...>{cuco::empty_key<Key>{this->empty_key_sentinel()},
this->impl_.key_eq(),
this->key_eq(),
probing_scheme,
{},
this->impl_.storage_ref()};
this->storage_ref()};
}

namespace detail {
Expand Down
Loading

0 comments on commit 5602381

Please sign in to comment.