Skip to content

Commit

Permalink
Add static_multimap_ref::for_each
Browse files Browse the repository at this point in the history
  • Loading branch information
sleeepyjack committed Sep 5, 2024
1 parent 4454de4 commit 7b7b553
Show file tree
Hide file tree
Showing 3 changed files with 284 additions and 1 deletion.
109 changes: 109 additions & 0 deletions include/cuco/detail/static_multimap/static_multimap_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,115 @@ class operator_impl<
}
};

template <typename Key,
          typename T,
          cuda::thread_scope Scope,
          typename KeyEqual,
          typename ProbingScheme,
          typename StorageRef,
          typename... Operators>
class operator_impl<
  op::for_each_tag,
  static_multimap_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>> {
  using base_type = static_multimap_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef>;
  using ref_type =
    static_multimap_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>;

  static constexpr auto cg_size = base_type::cg_size;

 public:
  /**
   * @brief Invokes a callback for every element in the container whose key is equivalent to the
   * probe key.
   *
   * @note The callback receives an un-incrementable input iterator pointing at the element whose
   * key is equivalent to `key`.
   *
   * @tparam ProbeKey Probe key type
   * @tparam CallbackOp Unary callback functor or device lambda
   *
   * @param key The key to search for
   * @param callback_op Function to call on every element found
   */
  template <class ProbeKey, class CallbackOp>
  __device__ void for_each(ProbeKey const& key, CallbackOp&& callback_op) const noexcept
  {
    // CRTP downcast from the operator mixin to the fully-typed ref before delegating.
    auto const& container = static_cast<ref_type const&>(*this);
    container.impl_.for_each(key, std::forward<CallbackOp>(callback_op));
  }

  /**
   * @brief Invokes a callback for every element in the container whose key is equivalent to the
   * probe key.
   *
   * @note The callback receives an un-incrementable input iterator pointing at the element whose
   * key is equivalent to `key`.
   *
   * @note This overload uses cooperative-group semantics: any thread of `group` that finds a
   * matching element invokes the callback. When several matches land in the same probing window,
   * each thread holding a match calls the callback with its own element.
   *
   * @note Synchronizing `group` inside `callback_op` is undefined behavior.
   *
   * @tparam ProbeKey Probe key type
   * @tparam CallbackOp Unary callback functor or device lambda
   *
   * @param group The Cooperative Group used to perform this operation
   * @param key The key to search for
   * @param callback_op Function to call on every element found
   */
  template <class ProbeKey, class CallbackOp>
  __device__ void for_each(cooperative_groups::thread_block_tile<cg_size> const& group,
                           ProbeKey const& key,
                           CallbackOp&& callback_op) const noexcept
  {
    // CRTP downcast from the operator mixin to the fully-typed ref before delegating.
    auto const& container = static_cast<ref_type const&>(*this);
    container.impl_.for_each(group, key, std::forward<CallbackOp>(callback_op));
  }

  /**
   * @brief Invokes a callback for every element in the container whose key is equivalent to the
   * probe key, and additionally allows work that must synchronize the Cooperative Group performing
   * this operation.
   *
   * @note The callback receives an un-incrementable input iterator pointing at the element whose
   * key is equivalent to `key`.
   *
   * @note This overload uses cooperative-group semantics: any thread of `group` that finds a
   * matching element invokes the callback. When several matches land in the same probing window,
   * each thread holding a match calls the callback with its own element.
   *
   * @note Synchronizing `group` inside `callback_op` is undefined behavior.
   *
   * @note `sync_op` may be used for work that requires synchronizing the threads of `group`
   * in between probing steps, where the number of probing steps performed between synchronization
   * points is capped by `window_size * cg_size`. It is invoked right after the current probing
   * window has been traversed.
   *
   * @tparam ProbeKey Probe key type
   * @tparam CallbackOp Unary callback functor or device lambda
   * @tparam SyncOp Functor or device lambda which accepts the current `group` object
   *
   * @param group The Cooperative Group used to perform this operation
   * @param key The key to search for
   * @param callback_op Function to call on every element found
   * @param sync_op Function that is allowed to synchronize `group` in between probing windows
   */
  template <class ProbeKey, class CallbackOp, class SyncOp>
  __device__ void for_each(cooperative_groups::thread_block_tile<cg_size> const& group,
                           ProbeKey const& key,
                           CallbackOp&& callback_op,
                           SyncOp&& sync_op) const noexcept
  {
    // CRTP downcast from the operator mixin to the fully-typed ref before delegating.
    auto const& container = static_cast<ref_type const&>(*this);
    container.impl_.for_each(
      group, key, std::forward<CallbackOp>(callback_op), std::forward<SyncOp>(sync_op));
  }
};

template <typename Key,
typename T,
cuda::thread_scope Scope,
Expand Down
3 changes: 2 additions & 1 deletion tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ ConfigureTest(STATIC_MULTIMAP_TEST
static_multimap/insert_if_test.cu
static_multimap/multiplicity_test.cu
static_multimap/non_match_test.cu
static_multimap/pair_function_test.cu)
static_multimap/pair_function_test.cu
static_multimap/for_each_test.cu)

###################################################################################################
# - dynamic_bitset tests --------------------------------------------------------------------------
Expand Down
173 changes: 173 additions & 0 deletions tests/static_multimap/for_each_test.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <test_utils.hpp>

#include <cuco/detail/utility/cuda.hpp>
#include <cuco/static_multimap.cuh>

#include <cuda/atomic>
#include <cuda/functional>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>

#include <catch2/catch_template_test_macros.hpp>

#include <cstddef>

/**
 * @brief Verifies the scalar `for_each` overload: every probe key must yield exactly
 * `multiplicity` matching `{key, key}` pairs, otherwise `error_counter` is bumped.
 *
 * Launch: 1D grid-stride; requires a ref built with `cg_size == 1`.
 */
template <class Ref, class InputIt, class AtomicErrorCounter>
CUCO_KERNEL void for_each_check_scalar(Ref ref,
                                       InputIt first,
                                       std::size_t n,
                                       std::size_t multiplicity,
                                       AtomicErrorCounter* error_counter)
{
  static_assert(Ref::cg_size == 1, "Scalar test must have cg_size==1");

  auto const stride = cuco::detail::grid_stride();

  // Grid-stride loop: each thread probes one key per iteration.
  for (auto idx = cuco::detail::global_thread_id(); idx < n; idx += stride) {
    auto const& probe_key   = *(first + idx);
    std::size_t num_matches = 0;
    ref.for_each(probe_key, [&] __device__(auto const slot) {
      auto const [slot_key, slot_value] = slot;
      // Pairs were inserted as {i, i}, so a valid match has key == probe and value == key.
      if (ref.key_eq()(probe_key, slot_key) and ref.key_eq()(slot_key, slot_value)) {
        ++num_matches;
      }
    });
    // Each unique key was inserted exactly `multiplicity` times.
    if (num_matches != multiplicity) { error_counter->fetch_add(1, cuda::memory_order_relaxed); }
  }
}

/**
 * @brief Verifies the cooperative-group `for_each` overloads: each tile of `Ref::cg_size`
 * threads probes one key and the tile-wide match count must equal `multiplicity`, otherwise
 * `error_counter` is bumped once (by the tile's rank-0 thread).
 *
 * @tparam Synced When `true`, exercises the overload that additionally takes a `sync_op`
 * which synchronizes the tile in between probing windows.
 *
 * Launch: 1D grid-stride with `Ref::cg_size` threads cooperating per probe key; the block
 * size must be a multiple of `Ref::cg_size` so that `tiled_partition` is well-formed.
 */
template <bool Synced, class Ref, class InputIt, class AtomicErrorCounter>
CUCO_KERNEL void for_each_check_cooperative(Ref ref,
                                            InputIt first,
                                            std::size_t n,
                                            std::size_t multiplicity,
                                            AtomicErrorCounter* error_counter)
{
  auto const loop_stride = cuco::detail::grid_stride() / Ref::cg_size;
  auto idx               = cuco::detail::global_thread_id() / Ref::cg_size;

  // The tile partition is loop-invariant; create it once instead of on every iteration.
  auto const tile =
    cooperative_groups::tiled_partition<Ref::cg_size>(cooperative_groups::this_thread_block());

  while (idx < n) {
    auto const& key            = *(first + idx);
    std::size_t thread_matches = 0;
    if constexpr (Synced) {
      ref.for_each(
        tile,
        key,
        [&] __device__(auto const slot) {
          auto const [slot_key, slot_value] = slot;
          // Pairs were inserted as {i, i}, so a valid match has key == probe and value == key.
          if (ref.key_eq()(key, slot_key) and ref.key_eq()(slot_key, slot_value)) {
            thread_matches++;
          }
        },
        [] __device__(auto const& group) { group.sync(); });
    } else {
      ref.for_each(tile, key, [&] __device__(auto const slot) {
        auto const [slot_key, slot_value] = slot;
        if (ref.key_eq()(key, slot_key) and ref.key_eq()(slot_key, slot_value)) {
          thread_matches++;
        }
      });
    }
    // Matches may be scattered across the tile's threads; sum them tile-wide.
    auto const tile_matches =
      cooperative_groups::reduce(tile, thread_matches, cooperative_groups::plus<std::size_t>());
    if (tile_matches != multiplicity and tile.thread_rank() == 0) {
      error_counter->fetch_add(1, cuda::memory_order_relaxed);
    }
    idx += loop_stride;
  }
}

TEMPLATE_TEST_CASE_SIG(
  "static_multimap for_each tests",
  "",
  ((typename Key, cuco::test::probe_sequence Probe, int CGSize), Key, Probe, CGSize),
  (int32_t, cuco::test::probe_sequence::double_hashing, 1),
  (int32_t, cuco::test::probe_sequence::double_hashing, 2),
  (int64_t, cuco::test::probe_sequence::double_hashing, 1),
  (int64_t, cuco::test::probe_sequence::double_hashing, 2),
  (int32_t, cuco::test::probe_sequence::linear_probing, 1),
  (int32_t, cuco::test::probe_sequence::linear_probing, 2),
  (int64_t, cuco::test::probe_sequence::linear_probing, 1),
  (int64_t, cuco::test::probe_sequence::linear_probing, 2))
{
  // Insert each of `num_unique_keys` keys `key_multiplicity` times as {key, key} pairs.
  constexpr size_t num_unique_keys{400};
  constexpr size_t key_multiplicity{5};
  constexpr size_t num_keys{num_unique_keys * key_multiplicity};

  using probe = std::conditional_t<Probe == cuco::test::probe_sequence::linear_probing,
                                   cuco::linear_probing<CGSize, cuco::default_hash_function<Key>>,
                                   cuco::double_hashing<CGSize, cuco::default_hash_function<Key>>>;

  // Sized for all insertions; -1 is the empty sentinel for both keys and values.
  auto map = cuco::experimental::static_multimap{num_keys,
                                                 cuco::empty_key<Key>{-1},
                                                 cuco::empty_value<Key>{-1},
                                                 {},
                                                 probe{},
                                                 {},
                                                 cuco::storage<2>{}};

  // Map the sequence [0, num_keys) onto [0, num_unique_keys) to create duplicates.
  auto unique_keys_begin  = thrust::counting_iterator<Key>(0);
  auto gen_duplicate_keys = cuda::proclaim_return_type<Key>(
    [] __device__(auto const& k) { return static_cast<Key>(k % num_unique_keys); });
  auto keys_begin = thrust::make_transform_iterator(unique_keys_begin, gen_duplicate_keys);

  auto const pairs_begin = thrust::make_transform_iterator(
    keys_begin, cuda::proclaim_return_type<cuco::pair<Key, Key>>([] __device__(auto i) {
      return cuco::pair<Key, Key>{i, i};
    }));

  map.insert(pairs_begin, pairs_begin + num_keys);

  // Pinned host atomic so both device kernels and host assertions can access the count.
  using error_counter_type = cuda::atomic<std::size_t, cuda::thread_scope_system>;
  error_counter_type* error_counter;
  CUCO_CUDA_TRY(cudaMallocHost(&error_counter, sizeof(error_counter_type)));
  new (error_counter) error_counter_type{0};

  auto const grid_size  = cuco::detail::grid_size(num_unique_keys, CGSize);
  auto const block_size = cuco::detail::default_block_size();

  // Scalar overload only exists for cg_size == 1.
  if constexpr (CGSize == 1) {
    for_each_check_scalar<<<grid_size, block_size>>>(
      map.ref(cuco::for_each), unique_keys_begin, num_unique_keys, key_multiplicity, error_counter);
    CUCO_CUDA_TRY(cudaDeviceSynchronize());
    REQUIRE(error_counter->load() == 0);
    error_counter->store(0);
  }

  // Cooperative-group overload without inter-window synchronization.
  for_each_check_cooperative<false><<<grid_size, block_size>>>(
    map.ref(cuco::for_each), unique_keys_begin, num_unique_keys, key_multiplicity, error_counter);
  CUCO_CUDA_TRY(cudaDeviceSynchronize());
  REQUIRE(error_counter->load() == 0);
  error_counter->store(0);

  // Cooperative-group overload with a sync_op between probing windows.
  for_each_check_cooperative<true><<<grid_size, block_size>>>(
    map.ref(cuco::for_each), unique_keys_begin, num_unique_keys, key_multiplicity, error_counter);
  CUCO_CUDA_TRY(cudaDeviceSynchronize());
  REQUIRE(error_counter->load() == 0);

  // cuda::atomic is trivially destructible, so freeing the pinned buffer suffices.
  CUCO_CUDA_TRY(cudaFreeHost(error_counter));
}

0 comments on commit 7b7b553

Please sign in to comment.