Skip to content

Commit

Permalink
Fix hash table full occupancy bug
Browse files Browse the repository at this point in the history
  • Loading branch information
PointKernel committed Nov 28, 2024
1 parent b8429d4 commit 740dbae
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 27 deletions.
86 changes: 59 additions & 27 deletions include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,10 @@ class open_addressing_ref_impl {
{
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");

auto const val = this->heterogeneous_value(value);
auto const key = this->extract_key(val);
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
auto const val = this->heterogeneous_value(value);
auto const key = this->extract_key(val);
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;

while (true) {
auto const bucket_slots = storage_ref_[*probing_iter];
Expand Down Expand Up @@ -411,6 +412,7 @@ class open_addressing_ref_impl {
}
}
++probing_iter;
if (*probing_iter == init_idx) { return false; }
}
}

Expand All @@ -428,9 +430,10 @@ class open_addressing_ref_impl {
__device__ bool insert(cooperative_groups::thread_block_tile<cg_size> const& group,
Value const& value) noexcept
{
auto const val = this->heterogeneous_value(value);
auto const key = this->extract_key(val);
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
auto const val = this->heterogeneous_value(value);
auto const key = this->extract_key(val);
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;

while (true) {
auto const bucket_slots = storage_ref_[*probing_iter];
Expand Down Expand Up @@ -483,6 +486,7 @@ class open_addressing_ref_impl {
}
} else {
++probing_iter;
if (*probing_iter == init_idx) { return false; }
}
}
}
Expand Down Expand Up @@ -513,9 +517,10 @@ class open_addressing_ref_impl {
"insert_and_find is not supported for pair types larger than 8 bytes on pre-Volta GPUs.");
#endif

auto const val = this->heterogeneous_value(value);
auto const key = this->extract_key(val);
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
auto const val = this->heterogeneous_value(value);
auto const key = this->extract_key(val);
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;

while (true) {
auto const bucket_slots = storage_ref_[*probing_iter];
Expand Down Expand Up @@ -554,6 +559,7 @@ class open_addressing_ref_impl {
}
}
++probing_iter;
if (*probing_iter == init_idx) { return {this->end(), false}; }
};
}

Expand Down Expand Up @@ -584,9 +590,10 @@ class open_addressing_ref_impl {
"insert_and_find is not supported for pair types larger than 8 bytes on pre-Volta GPUs.");
#endif

auto const val = this->heterogeneous_value(value);
auto const key = this->extract_key(val);
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
auto const val = this->heterogeneous_value(value);
auto const key = this->extract_key(val);
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;

while (true) {
auto const bucket_slots = storage_ref_[*probing_iter];
Expand Down Expand Up @@ -653,6 +660,7 @@ class open_addressing_ref_impl {
}
} else {
++probing_iter;
if (*probing_iter == init_idx) { return {this->end(), false}; }
}
}
}
Expand All @@ -671,7 +679,8 @@ class open_addressing_ref_impl {
{
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");

auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;

while (true) {
auto const bucket_slots = storage_ref_[*probing_iter];
Expand All @@ -696,6 +705,7 @@ class open_addressing_ref_impl {
}
}
++probing_iter;
if (*probing_iter == init_idx) { return false; }
}
}

Expand All @@ -713,7 +723,8 @@ class open_addressing_ref_impl {
__device__ bool erase(cooperative_groups::thread_block_tile<cg_size> const& group,
ProbeKey const& key) noexcept
{
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;

while (true) {
auto const bucket_slots = storage_ref_[*probing_iter];
Expand Down Expand Up @@ -750,6 +761,7 @@ class open_addressing_ref_impl {
if (group.any(state == detail::equal_result::EMPTY)) { return false; }

++probing_iter;
if (*probing_iter == init_idx) { return false; }
}
}

Expand All @@ -769,7 +781,8 @@ class open_addressing_ref_impl {
[[nodiscard]] __device__ bool contains(ProbeKey const& key) const noexcept
{
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;

while (true) {
// TODO atomic_ref::load if insert operator is present
Expand All @@ -783,6 +796,7 @@ class open_addressing_ref_impl {
}
}
++probing_iter;
if (*probing_iter == init_idx) { return false; }
}
}

Expand All @@ -803,7 +817,8 @@ class open_addressing_ref_impl {
[[nodiscard]] __device__ bool contains(
cooperative_groups::thread_block_tile<cg_size> const& group, ProbeKey const& key) const noexcept
{
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;

while (true) {
auto const bucket_slots = storage_ref_[*probing_iter];
Expand All @@ -821,6 +836,7 @@ class open_addressing_ref_impl {
if (group.any(state == detail::equal_result::EMPTY)) { return false; }

++probing_iter;
if (*probing_iter == init_idx) { return false; }
}
}

Expand All @@ -840,7 +856,8 @@ class open_addressing_ref_impl {
[[nodiscard]] __device__ const_iterator find(ProbeKey const& key) const noexcept
{
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;

while (true) {
// TODO atomic_ref::load if insert operator is present
Expand All @@ -859,6 +876,7 @@ class open_addressing_ref_impl {
}
}
++probing_iter;
if (*probing_iter == init_idx) { return this->end(); }
}
}

Expand All @@ -879,7 +897,8 @@ class open_addressing_ref_impl {
[[nodiscard]] __device__ const_iterator find(
cooperative_groups::thread_block_tile<cg_size> const& group, ProbeKey const& key) const noexcept
{
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;

while (true) {
auto const bucket_slots = storage_ref_[*probing_iter];
Expand Down Expand Up @@ -908,6 +927,7 @@ class open_addressing_ref_impl {
if (group.any(state == detail::equal_result::EMPTY)) { return this->end(); }

++probing_iter;
if (*probing_iter == init_idx) { return this->end(); }
}
}

Expand All @@ -926,8 +946,9 @@ class open_addressing_ref_impl {
if constexpr (not allows_duplicates) {
return static_cast<size_type>(this->contains(key));
} else {
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
size_type count = 0;
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;
size_type count = 0;

while (true) {
// TODO atomic_ref::load if insert operator is present
Expand All @@ -942,6 +963,7 @@ class open_addressing_ref_impl {
}
}
++probing_iter;
if (*probing_iter == init_idx) { return count; }
}
}
}
Expand All @@ -960,8 +982,9 @@ class open_addressing_ref_impl {
[[nodiscard]] __device__ size_type count(
cooperative_groups::thread_block_tile<cg_size> const& group, ProbeKey const& key) const noexcept
{
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
size_type count = 0;
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;
size_type count = 0;

while (true) {
auto const bucket_slots = storage_ref_[*probing_iter];
Expand All @@ -978,6 +1001,7 @@ class open_addressing_ref_impl {

if (group.any(state == detail::equal_result::EMPTY)) { return count; }
++probing_iter;
if (*probing_iter == init_idx) { return count; }
}
}

Expand Down Expand Up @@ -1177,6 +1201,7 @@ class open_addressing_ref_impl {
auto const& probe_key = *(input_probe + idx);
auto probing_iter =
this->probing_scheme_(probing_tile, probe_key, this->storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;

bool running = true;
[[maybe_unused]] bool found_match = false;
Expand Down Expand Up @@ -1277,6 +1302,7 @@ class open_addressing_ref_impl {

// onto the next probing bucket
++probing_iter;
if (*probing_iter == init_idx) { running = false; }
} // while running
} // if active_flag

Expand Down Expand Up @@ -1305,7 +1331,8 @@ class open_addressing_ref_impl {
__device__ void for_each(ProbeKey const& key, CallbackOp&& callback_op) const noexcept
{
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
auto probing_iter = this->probing_scheme_(key, this->storage_ref_.bucket_extent());
auto probing_iter = this->probing_scheme_(key, this->storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;

while (true) {
// TODO atomic_ref::load if insert operator is present
Expand All @@ -1325,6 +1352,7 @@ class open_addressing_ref_impl {
}
}
++probing_iter;
if (*probing_iter == init_idx) { return; }
}
}

Expand Down Expand Up @@ -1352,8 +1380,9 @@ class open_addressing_ref_impl {
ProbeKey const& key,
CallbackOp&& callback_op) const noexcept
{
auto probing_iter = this->probing_scheme_(group, key, this->storage_ref_.bucket_extent());
bool empty = false;
auto probing_iter = this->probing_scheme_(group, key, this->storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;
bool empty = false;

while (true) {
// TODO atomic_ref::load if insert operator is present
Expand All @@ -1378,6 +1407,7 @@ class open_addressing_ref_impl {
if (group.any(empty)) { return; }

++probing_iter;
if (*probing_iter == init_idx) { return; }
}
}

Expand Down Expand Up @@ -1414,8 +1444,9 @@ class open_addressing_ref_impl {
CallbackOp&& callback_op,
SyncOp&& sync_op) const noexcept
{
auto probing_iter = this->probing_scheme_(group, key, this->storage_ref_.bucket_extent());
bool empty = false;
auto probing_iter = this->probing_scheme_(group, key, this->storage_ref_.bucket_extent());
auto const init_idx = *probing_iter;
bool empty = false;

while (true) {
// TODO atomic_ref::load if insert operator is present
Expand All @@ -1441,6 +1472,7 @@ class open_addressing_ref_impl {
if (group.any(empty)) { return; }

++probing_iter;
if (*probing_iter == init_idx) { return; }
}
}

Expand Down
8 changes: 8 additions & 0 deletions include/cuco/detail/static_map/static_map_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ class operator_impl<
auto& probing_scheme = ref_.impl_.probing_scheme();
auto storage_ref = ref_.impl_.storage_ref();
auto probing_iter = probing_scheme(key, storage_ref.bucket_extent());
auto const init_idx = *probing_iter;

while (true) {
auto const bucket_slots = storage_ref[*probing_iter];
Expand All @@ -514,6 +515,7 @@ class operator_impl<
}
}
++probing_iter;
if (*probing_iter == init_idx) { return; }
}
}

Expand All @@ -539,6 +541,7 @@ class operator_impl<
auto& probing_scheme = ref_.impl_.probing_scheme();
auto storage_ref = ref_.impl_.storage_ref();
auto probing_iter = probing_scheme(group, key, storage_ref.bucket_extent());
auto const init_idx = *probing_iter;

while (true) {
auto const bucket_slots = storage_ref[*probing_iter];
Expand Down Expand Up @@ -578,6 +581,7 @@ class operator_impl<
if (group.shfl(status, src_lane)) { return; }
} else {
++probing_iter;
if (*probing_iter == init_idx) { return; }
}
}
}
Expand Down Expand Up @@ -855,6 +859,7 @@ class operator_impl<
auto& probing_scheme = ref_.impl_.probing_scheme();
auto storage_ref = ref_.impl_.storage_ref();
auto probing_iter = probing_scheme(key, storage_ref.bucket_extent());
auto const init_idx = *probing_iter;
auto const empty_value = ref_.empty_value_sentinel();

// wait for payload only when init != sentinel and insert strategy is not `packed_cas`
Expand Down Expand Up @@ -894,6 +899,7 @@ class operator_impl<
}
}
++probing_iter;
if (*probing_iter == init_idx) { return false; }
}
}

Expand Down Expand Up @@ -929,6 +935,7 @@ class operator_impl<
auto& probing_scheme = ref_.impl_.probing_scheme();
auto storage_ref = ref_.impl_.storage_ref();
auto probing_iter = probing_scheme(group, key, storage_ref.bucket_extent());
auto const init_idx = *probing_iter;
auto const empty_value = ref_.empty_value_sentinel();

// wait for payload only when init != sentinel and insert strategy is not `packed_cas`
Expand Down Expand Up @@ -987,6 +994,7 @@ class operator_impl<
}
} else {
++probing_iter;
if (*probing_iter == init_idx) { return false; }
}
}
}
Expand Down

0 comments on commit 740dbae

Please sign in to comment.