diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index 1c7fa518a..1676fb816 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -378,9 +378,10 @@ class open_addressing_ref_impl { { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); - auto const val = this->heterogeneous_value(value); - auto const key = this->extract_key(val); - auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent()); + auto const val = this->heterogeneous_value(value); + auto const key = this->extract_key(val); + auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent()); + auto const init_idx = *probing_iter; while (true) { auto const bucket_slots = storage_ref_[*probing_iter]; @@ -411,6 +412,7 @@ class open_addressing_ref_impl { } } ++probing_iter; + if (*probing_iter == init_idx) { return false; } } } @@ -428,9 +430,10 @@ class open_addressing_ref_impl { __device__ bool insert(cooperative_groups::thread_block_tile const& group, Value const& value) noexcept { - auto const val = this->heterogeneous_value(value); - auto const key = this->extract_key(val); - auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent()); + auto const val = this->heterogeneous_value(value); + auto const key = this->extract_key(val); + auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent()); + auto const init_idx = *probing_iter; while (true) { auto const bucket_slots = storage_ref_[*probing_iter]; @@ -483,6 +486,7 @@ class open_addressing_ref_impl { } } else { ++probing_iter; + if (*probing_iter == init_idx) { return false; } } } } @@ -513,9 +517,10 @@ class open_addressing_ref_impl { "insert_and_find is not supported for pair types larger than 8 bytes on pre-Volta GPUs."); #endif - auto const val = this->heterogeneous_value(value); - auto const key = this->extract_key(val); - auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent()); + auto const val = this->heterogeneous_value(value); + auto const key = this->extract_key(val); + auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent()); + auto const init_idx = *probing_iter; while (true) { auto const bucket_slots = storage_ref_[*probing_iter]; @@ -554,6 +559,7 @@ class open_addressing_ref_impl { } } ++probing_iter; + if (*probing_iter == init_idx) { return {this->end(), false}; } }; } @@ -584,9 +590,10 @@ class open_addressing_ref_impl { "insert_and_find is not supported for pair types larger than 8 bytes on pre-Volta GPUs."); #endif - auto const val = this->heterogeneous_value(value); - auto const key = this->extract_key(val); - auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent()); + auto const val = this->heterogeneous_value(value); + auto const key = this->extract_key(val); + auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent()); + auto const init_idx = *probing_iter; while (true) { auto const bucket_slots = storage_ref_[*probing_iter]; @@ -653,6 +660,7 @@ class open_addressing_ref_impl { } } else { ++probing_iter; + if (*probing_iter == init_idx) { return {this->end(), false}; } } } } @@ -671,7 +679,8 @@ class open_addressing_ref_impl { { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); - auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent()); + auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent()); + auto const init_idx = *probing_iter; while (true) { auto const bucket_slots = storage_ref_[*probing_iter]; @@ -696,6 +705,7 @@ class open_addressing_ref_impl { } } ++probing_iter; + if (*probing_iter == init_idx) { return false; } } } @@ -713,7 +723,8 @@ class open_addressing_ref_impl { __device__ bool erase(cooperative_groups::thread_block_tile const& group, ProbeKey const& key) noexcept { - auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent()); + auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent()); + auto const init_idx = *probing_iter; while (true) { auto const bucket_slots = storage_ref_[*probing_iter]; @@ -750,6 +761,7 @@ class open_addressing_ref_impl { if (group.any(state == detail::equal_result::EMPTY)) { return false; } ++probing_iter; + if (*probing_iter == init_idx) { return false; } } } @@ -769,7 +781,8 @@ class open_addressing_ref_impl { [[nodiscard]] __device__ bool contains(ProbeKey const& key) const noexcept { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); - auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent()); + auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent()); + auto const init_idx = *probing_iter; while (true) { // TODO atomic_ref::load if insert operator is present @@ -783,6 +796,7 @@ class open_addressing_ref_impl { } } ++probing_iter; + if (*probing_iter == init_idx) { return false; } } } @@ -803,7 +817,8 @@ class open_addressing_ref_impl { [[nodiscard]] __device__ bool contains( cooperative_groups::thread_block_tile const& group, ProbeKey const& key) const noexcept { - auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent()); + auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent()); + auto const init_idx = *probing_iter; while (true) { auto const bucket_slots = storage_ref_[*probing_iter]; @@ -821,6 +836,7 @@ class open_addressing_ref_impl { if (group.any(state == detail::equal_result::EMPTY)) { return false; } ++probing_iter; + if (*probing_iter == init_idx) { return false; } } } @@ -840,7 +856,8 @@ class open_addressing_ref_impl { [[nodiscard]] __device__ const_iterator find(ProbeKey const& key) const noexcept { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); - auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent()); + auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent()); + auto const init_idx = *probing_iter; while (true) { // TODO atomic_ref::load if insert operator is present @@ -859,6 +876,7 @@ class open_addressing_ref_impl { } } ++probing_iter; + if (*probing_iter == init_idx) { return this->end(); } } } @@ -879,7 +897,8 @@ class open_addressing_ref_impl { [[nodiscard]] __device__ const_iterator find( cooperative_groups::thread_block_tile const& group, ProbeKey const& key) const noexcept { - auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent()); + auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent()); + auto const init_idx = *probing_iter; while (true) { auto const bucket_slots = storage_ref_[*probing_iter]; @@ -908,6 +927,7 @@ class open_addressing_ref_impl { if (group.any(state == detail::equal_result::EMPTY)) { return this->end(); } ++probing_iter; + if (*probing_iter == init_idx) { return this->end(); } } } @@ -926,8 +946,9 @@ class open_addressing_ref_impl { if constexpr (not allows_duplicates) { return static_cast(this->contains(key)); } else { - auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent()); - size_type count = 0; + auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent()); + auto const init_idx = *probing_iter; + size_type count = 0; while (true) { // TODO atomic_ref::load if insert operator is present @@ -942,6 +963,7 @@ class open_addressing_ref_impl { } } ++probing_iter; + if (*probing_iter == init_idx) { return count; } } } } @@ -960,8 +982,9 @@ class open_addressing_ref_impl { [[nodiscard]] __device__ size_type count( cooperative_groups::thread_block_tile const& group, ProbeKey const& key) const noexcept { - auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent()); - size_type count = 0; + auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent()); + auto const init_idx = *probing_iter; + size_type count = 0; while (true) { auto const bucket_slots = storage_ref_[*probing_iter]; @@ -978,6 +1001,7 @@ class open_addressing_ref_impl { if (group.any(state == detail::equal_result::EMPTY)) { return count; } ++probing_iter; + if (*probing_iter == init_idx) { return count; } } } @@ -1177,6 +1201,7 @@ class open_addressing_ref_impl { auto const& probe_key = *(input_probe + idx); auto probing_iter = this->probing_scheme_(probing_tile, probe_key, this->storage_ref_.bucket_extent()); + auto const init_idx = *probing_iter; bool running = true; [[maybe_unused]] bool found_match = false; @@ -1277,6 +1302,7 @@ class open_addressing_ref_impl { // onto the next probing bucket ++probing_iter; + if (*probing_iter == init_idx) { running = false; } } // while running } // if active_flag @@ -1305,7 +1331,8 @@ class open_addressing_ref_impl { __device__ void for_each(ProbeKey const& key, CallbackOp&& callback_op) const noexcept { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); - auto probing_iter = this->probing_scheme_(key, this->storage_ref_.bucket_extent()); + auto probing_iter = this->probing_scheme_(key, this->storage_ref_.bucket_extent()); + auto const init_idx = *probing_iter; while (true) { // TODO atomic_ref::load if insert operator is present @@ -1325,6 +1352,7 @@ class open_addressing_ref_impl { } } ++probing_iter; + if (*probing_iter == init_idx) { return; } } } @@ -1352,8 +1380,9 @@ class open_addressing_ref_impl { ProbeKey const& key, CallbackOp&& callback_op) const noexcept { - auto probing_iter = this->probing_scheme_(group, key, this->storage_ref_.bucket_extent()); - bool empty = false; + auto probing_iter = this->probing_scheme_(group, key, this->storage_ref_.bucket_extent()); + auto const init_idx = *probing_iter; + bool empty = false; while (true) { // TODO atomic_ref::load if insert operator is present @@ -1378,6 +1407,7 @@ class open_addressing_ref_impl { if (group.any(empty)) { return; } ++probing_iter; + if (*probing_iter == init_idx) { return; } } } @@ -1414,8 +1444,9 @@ class open_addressing_ref_impl { CallbackOp&& callback_op, SyncOp&& sync_op) const noexcept { - auto probing_iter = this->probing_scheme_(group, key, this->storage_ref_.bucket_extent()); - bool empty = false; + auto probing_iter = this->probing_scheme_(group, key, this->storage_ref_.bucket_extent()); + auto const init_idx = *probing_iter; + bool empty = false; while (true) { // TODO atomic_ref::load if insert operator is present @@ -1441,6 +1472,7 @@ class open_addressing_ref_impl { if (group.any(empty)) { return; } ++probing_iter; + if (*probing_iter == init_idx) { return; } } } diff --git a/include/cuco/detail/static_map/static_map_ref.inl b/include/cuco/detail/static_map/static_map_ref.inl index f06f03fdb..f53a62398 100644 --- a/include/cuco/detail/static_map/static_map_ref.inl +++ b/include/cuco/detail/static_map/static_map_ref.inl @@ -493,6 +493,7 @@ class operator_impl< auto& probing_scheme = ref_.impl_.probing_scheme(); auto storage_ref = ref_.impl_.storage_ref(); auto probing_iter = probing_scheme(key, storage_ref.bucket_extent()); + auto const init_idx = *probing_iter; while (true) { auto const bucket_slots = storage_ref[*probing_iter]; @@ -514,6 +515,7 @@ class operator_impl< } } ++probing_iter; + if (*probing_iter == init_idx) { return; } } } @@ -539,6 +541,7 @@ class operator_impl< auto& probing_scheme = ref_.impl_.probing_scheme(); auto storage_ref = ref_.impl_.storage_ref(); auto probing_iter = probing_scheme(group, key, storage_ref.bucket_extent()); + auto const init_idx = *probing_iter; while (true) { auto const bucket_slots = storage_ref[*probing_iter]; @@ -578,6 +581,7 @@ class operator_impl< if (group.shfl(status, src_lane)) { return; } } else { ++probing_iter; + if (*probing_iter == init_idx) { return; } } } } @@ -855,6 +859,7 @@ class operator_impl< auto& probing_scheme = ref_.impl_.probing_scheme(); auto storage_ref = ref_.impl_.storage_ref(); auto probing_iter = probing_scheme(key, storage_ref.bucket_extent()); + auto const init_idx = *probing_iter; auto const empty_value = ref_.empty_value_sentinel(); // wait for payload only when init != sentinel and insert strategy is not `packed_cas` @@ -894,6 +899,7 @@ class operator_impl< } } ++probing_iter; + if (*probing_iter == init_idx) { return false; } } } @@ -929,6 +935,7 @@ class operator_impl< auto& probing_scheme = ref_.impl_.probing_scheme(); auto storage_ref = ref_.impl_.storage_ref(); auto probing_iter = probing_scheme(group, key, storage_ref.bucket_extent()); + auto const init_idx = *probing_iter; auto const empty_value = ref_.empty_value_sentinel(); // wait for payload only when init != sentinel and insert strategy is not `packed_cas` @@ -987,6 +994,7 @@ class operator_impl< } } else { ++probing_iter; + if (*probing_iter == init_idx) { return false; } } } }