Skip to content

Commit

Permalink
Implement OA retrieve(_outer) and its multiset API (#537)
Browse files Browse the repository at this point in the history
  • Loading branch information
sleeepyjack authored Oct 18, 2024
1 parent dafcf45 commit 6816740
Show file tree
Hide file tree
Showing 20 changed files with 1,189 additions and 16 deletions.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,13 @@ We plan to add many GPU-accelerated, concurrent data structures to `cuCollection
#### Examples:
- [Host-bulk APIs](https://github.com/NVIDIA/cuCollections/blob/dev/examples/static_multimap/host_bulk_example.cu) (see [live example in godbolt](https://godbolt.org/clientstate/eJylVgtv2zYQ_isHDUXtVJYfaFDEjQN4bYoZK5whTlsUcaHQFG0TkUmNpOx6hv_77ijJlpsM67AWiCHe-7vvjtwFVlgrtbJB_34XyCTod8MgZWqRs4UI-gHPExaEgdW54fTdPpsqOIN3OtsauVg6aPAm9Dq9bgv_vA5h_Hn0fjSEdze3f9zcDu9GN-OIDLzRR8mFsiKBXCXCgFsKGGaM408pCeGzMJQN9KIONEhhGpSyadB8671sdQ4rtgWlHeRWoBtpYS5TAeI7F5kDqYDrVZZKpriAjXRLH6r049OBr6UTPXMM9RlaZPg1r2sCc4fU6d_Suazfbm82m4j5tCNtFu20ULbtj6N31-PJdQtTP5h9UinCC0b8mUuDhc-2wDLMjLMZ5puyDWgDbGEEypymzDdGOqkWIVg9dxtmhPeTSOuMnOXuBLwqT6y_roDwMYXADScwmkwD-HU4GU1C7-fL6O63m0938GV4ezsc342uJ3Bzi80avx9Rq_DrAwzHX-H30fh9CAKhw1Die2aoCkxVEqwiKTCcCHGSxlwXadlMcDmXHCoawUKvhVFYFmTCrGRBOEwy8X5SuZKOOX_2pDgfqj1VU_WLVDzNEwGXPOe6bcmEx6s8dXLFsojny6tTNbc0uXXtRKzRVbwW3GkTkdITFemEYShtc50rgj-uTp7Xt9hPgeR6XuoMUxbBWEU_ZuQrtf5QKocElKqx1jJpTtUO60Q6E0aPYhu7bSaQcgOkhHt7FK1ZmotCWIlIeLAQq8xtY_q0ggoRqWfuAFpd76VmX6gWBwflUpFUObbDUevBuqTft_IvNIQxqpx3XnY6nVKt3cZdgJom5w5bClU_irnrdjoh6oJNNUqLGlpd6rz_sb7TPhOqoe2zKb1WOdkIxtoVTOM0dxKVkfF8qVEFHpXekNcNTXqaIihWGIc5-rjo1IalQ6InzkWqWQJzRmSgeT_vvIh8ucgprPKUVJcVsGENuSvsW-bbBQjHGfTC0vgA_u5pG_anSt7b7rkW7PclsAWZ-v0T-l4WTjImzfO5XQHJbGPcrPXHCIYAMqhoS3V7tQiuBS4yggmW1I2NxuIcrjYblcbX0SIKYbdDPLGGXTfs4k8URYAnvXMiQnEM_qMU7usFHMahUWBWna_Yo4ifDNwlsvqq0WmGP69clEIEbTQru-JsJhY4YdXZ_TeI4xLOuMFyWrhN2OF2drlR8G_Q7iS8gMYY2tBrhiD3b2FfA3nkeYcTgBz0wWk8tWct0oWUaEkV7GycZlcmK1SC-R89TurdMnomPJsJdyC4iy68vnh5cXGx_2fCVKVceevY6XguMZDnx9GmYkajrnRM8OTU5xlCp84wakwxzJrz3BjyZSlxnzJea_eYdbE3mrRZ6NLFmwoFHqBqh5fu7vDsIdZ4mZkHsPl8Lr-jaoK3phMUhbkfQlEkhg8C1fLspfXQ9c58k_0mA3SX5S4mmuAO8zcGZV2E-em6m_9zOpMYr1Is1zZq-dSQ_ICRChI9bfbe702LgRAFZq3mktF9f9icdeI9HCI9VAwljGjSH-pl3ctvD5BoYdVLhzc9PiPCmi1KIysQwAQGg2evjIcDzGVBgnQ9vjhXRoq1-I8Qh0eQKrWCrD7KwWnZyFrU1lPDI7DEKacdS0Hlq5nwu9_nVS08sEudp_hEQ-6N4RX4QS-qs35o8b3ZOCHR4MdkBnAwPBClXC50XeKU4kuanqb4WDXHB3eg1px3e-d5F8U6c8VrPGjhvTvgr15130CLGb4c2FX8pgOtFl7KDv84rFkkrZStZv6JnspZzSfnPMXDdfGexgO8odVjsA8rOVL1RI5MDvbf_P-_AYKkJA4=))

### `static_multiset`

`cuco::static_multiset` is a fixed-size container that supports storing equivalent keys. It uses double hashing by default and supports switching to linear probing. See the Doxygen documentation in `static_multiset.cuh` for more detailed information.

#### Examples:
- [Host-bulk APIs](https://github.com/NVIDIA/cuCollections/blob/dev/examples/static_multiset/host_bulk_example.cu) (see [live example in godbolt](https://godbolt.org/clientstate/eJyVVw1vGkcQ_SuTqypDcnxZjSIRuyq1HRU1wpFxEkUhwsveACvf3dLdPQi1_N87s3sHhz_a1JZsuJ19--bNzFu4iyxaq3Ruo_7Xu0glUb8XR6nIF4VYYNSPZJGIKI6sLozk952XkxxewplebY1aLB00ZBOOu8e_xDD6NDwfDuDs8urD5dXgeng5anOsj3-vJOYWEyjyBA24JcJgJST9K1di-ISGicBxuwsNDphE5dokar71KFtdQCa2kGsHhUWCURbmKkXA7xJXDlQOUmerVIlcImyUW_qjShxPB76UIHrmBMUL2rGid_N6JAi3o84_S-dW_U5ns9m0hafd1mbRSUOw7bwfnl2Mxhctor7b9jFPSVkw-FehDCU-24JYETMpZsQ3FRvQBsTCIK05zcw3RjmVL2Kweu42wqDHSZR1Rs0KdyBexZPyrweQfCIn4QZjGI4nEfw-GA_Hscf5PLz-4_LjNXweXF0NRtfDizFcXlGxRudDLhW9eweD0Rf4czg6jwFJOjoKv68MZ0FUFcuKSdBwjHhAY64DLbtCqeZKQtVBsNBrNDmlBSs0mQq9RiQTj5OqTDnh_LNHyfmjOpN8kv-kcpkWCcKJLKTuWN4ip1mROmXRtWWx_PUwzC1NYV0nwTVBTdconTZtDnoUMi9yyeeL9On1VC-oZM8sWiouUqe1H56vNNUERXawySdrfWTnZWiS33zzLil6OivS2yl-FyQzUkZheWYUzuEcMxLIGeGQZLIsa9muD5SgyjMMFZ4hYfBhaPcjeE3xvOofA7UXdUu6pcZb69CTc6Mzj-o3U4l8UMFTy_VNNFdRUxnLgqXqFuFGUaWMu6EW8afcGKRuxDXe-GaEHTVifIvbPR8qrModTaLKG2utkuYkv6PndB4fQ6FTt10hnBJB95Y1A-h04CJbuS3YVLuQgkFuUMxdmDB-bdb0ehLxQ5VjSmKsRVqgbbMC1KvhHdilLlLaRCJiSg1Cm2whl-U5bikca0FeQ7kY0FIWhoeULIj_rwoHiXCC09mTlVwmnhlA5jnlhYoHZdLq7RMZFdmMcYMq7ADEJGiJPBtU2qTft-pvnLoabl5kU7_hFF53j7rd7h7xjIyPTAAkO5QilWbClpZAsT9DqgUVUvAo8A5R0Jl7YF6dhlXC7rZfv32SxB79NCxKVGljx6pTx2nWuXH_FpKrBtwN3ppJ4hQFgd5UqDdlaUMTtHogrO9IL2dQqtLTK892QAQPp-Ckqsevu-67qw6Iyy27-tw9rtT9fY04TTFpyqTDqO8qdteNoRfDcQztdgzqvtzwGY8Mux7z97bOFQWkO4O3gduQIZHBw6Ycv5z9_6gm4PGRd3RyDbcbGAh-0-8fGFotTw6sV-HYX5i7bRX5hgec4YJmrhkHeMwTft2tVWvoSVOp0jTkSiOogzMIu-T6ceDOfkOOz2IHKjvY4FsiQ6i5Al2E5AN88wWp6gpUkxpGMAhIg8gwFYX_S2fXlUXuQoN5cMMaWeaUCSeXpQ9Z7nz-qECFenBu-5kZYVxMprpwNJFTXqNx2V9XvPwj_MZUZDZ7dl9yNtr_I71Qnroyeoa28QSXw974DxivxPM4JdOr0vR9yxyId2g18PWA3ZQyjmH6DU45rFbD6hL5l4Y9ANqHHNKunjefcbPqnIe1esQSWk8fWEqg5tBoPAN2-lQ_NP39-vyWapabTbgL2gS3pTA4OaHrbVxI6lX7At4ReLIb1faECEU-23vAlK67p7e_Eyp9AWOdofOl2pD30QdQnS_qCCE5IlmYHPiquaevAvwBmz61mP03hihfS9k7fl30aFmvXPg6EbXozFP56lXvDbSEkctTm03fdKHVogvH0R9HumLSSkU2898xUjWrYUopU3q4Dt8K6AHdH_ltdB9X6-TjB-vUyNH9N__7D2OnfWU=))

### `dynamic_map`

`cuco::dynamic_map` links together multiple `cuco::static_map`s to provide a hash table that can grow as key-value pairs are inserted. It currently only provides host-bulk APIs. See the Doxygen documentation in `dynamic_map.cuh` for more detailed information.
Expand Down
1 change: 1 addition & 0 deletions benchmarks/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ ConfigureBench(STATIC_MAP_BENCH
# - static_multiset benchmarks --------------------------------------------------------------------
ConfigureBench(STATIC_MULTISET_BENCH
static_multiset/contains_bench.cu
static_multiset/retrieve_bench.cu
static_multiset/count_bench.cu
static_multiset/find_bench.cu
static_multiset/insert_bench.cu)
Expand Down
87 changes: 87 additions & 0 deletions benchmarks/static_multiset/retrieve_bench.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <benchmark_defaults.hpp>
#include <benchmark_utils.hpp>

#include <cuco/static_multiset.cuh>
#include <cuco/utility/key_generator.cuh>

#include <nvbench/nvbench.cuh>

#include <thrust/device_vector.h>
#include <thrust/transform.h>

using namespace cuco::benchmark;
using namespace cuco::utility;

/**
* @brief A benchmark evaluating `cuco::static_multiset::retrieve` performance
*/
template <typename Key, typename Dist>
void static_multiset_retrieve(nvbench::state& state, nvbench::type_list<Key, Dist>)
{
auto const num_keys = state.get_int64_or_default("NumInputs", defaults::N);
auto const occupancy = state.get_float64_or_default("Occupancy", defaults::OCCUPANCY);
auto const matching_rate = state.get_float64_or_default("MatchingRate", defaults::MATCHING_RATE);

std::size_t const size = num_keys / occupancy;

thrust::device_vector<Key> keys(num_keys);

key_generator gen;
gen.generate(dist_from_state<Dist>(state), keys.begin(), keys.end());

gen.dropout(keys.begin(), keys.end(), matching_rate);

state.add_element_count(num_keys);

cuco::static_multiset<Key> set{size, cuco::empty_key<Key>{-1}};
set.insert(keys.begin(), keys.end());

auto const output_size = set.count(keys.begin(), keys.end());
thrust::device_vector<Key> output_match(output_size);
auto output_probe_begin = thrust::discard_iterator{};

state.exec(nvbench::exec_tag::sync, [&](nvbench::launch& launch) {
set.retrieve(
keys.begin(), keys.end(), output_probe_begin, output_match.begin(), {launch.get_stream()});
});
}

NVBENCH_BENCH_TYPES(static_multiset_retrieve,
NVBENCH_TYPE_AXES(defaults::KEY_TYPE_RANGE,
nvbench::type_list<distribution::uniform>))
.set_name("static_multiset_retrieve_uniform_occupancy")
.set_type_axes_names({"Key", "Distribution"})
.set_max_noise(defaults::MAX_NOISE)
.add_float64_axis("Occupancy", defaults::OCCUPANCY_RANGE);

NVBENCH_BENCH_TYPES(static_multiset_retrieve,
NVBENCH_TYPE_AXES(defaults::KEY_TYPE_RANGE,
nvbench::type_list<distribution::uniform>))
.set_name("static_multiset_retrieve_uniform_matching_rate")
.set_type_axes_names({"Key", "Distribution"})
.set_max_noise(defaults::MAX_NOISE)
.add_float64_axis("MatchingRate", defaults::MATCHING_RATE_RANGE);

NVBENCH_BENCH_TYPES(static_multiset_retrieve,
NVBENCH_TYPE_AXES(defaults::KEY_TYPE_RANGE,
nvbench::type_list<distribution::uniform>))
.set_name("static_multiset_retrieve_uniform_multiplicity")
.set_type_axes_names({"Key", "Distribution"})
.set_max_noise(defaults::MAX_NOISE)
.add_int64_axis("Multiplicity", defaults::MULTIPLICITY_RANGE);
1 change: 1 addition & 0 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ ConfigureExample(STATIC_SET_DEVICE_REF_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/stat
ConfigureExample(STATIC_SET_DEVICE_SUBSETS_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/static_set/device_subsets_example.cu")
ConfigureExample(STATIC_SET_SHARED_MEMORY_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/static_set/shared_memory_example.cu")
ConfigureExample(STATIC_SET_MAPPING_TABLE_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/static_set/mapping_table_example.cu")
ConfigureExample(STATIC_MULTISET_HOST_BULK_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/static_multiset/host_bulk_example.cu")
ConfigureExample(STATIC_MAP_HOST_BULK_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/static_map/host_bulk_example.cu")
ConfigureExample(STATIC_MAP_DEVICE_SIDE_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/static_map/device_ref_example.cu")
ConfigureExample(STATIC_MAP_CUSTOM_TYPE_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/static_map/custom_type_example.cu")
Expand Down
82 changes: 82 additions & 0 deletions examples/static_multiset/host_bulk_example.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cuco/static_multiset.cuh>

#include <thrust/device_vector.h>
#include <thrust/functional.h>
#include <thrust/logical.h>
#include <thrust/sequence.h>

#include <iostream>
#include <limits>

/**
* @file host_bulk_example.cu
* @brief Demonstrates usage of the static_multiset "bulk" host APIs.
*
* The bulk APIs are only invocable from the host and are used for doing operations like `insert` or
* `retrieve` on a multiset of keys.
*
*/
int main(void)
{
using key_type = int;

// Empty slots are represented by reserved "sentinel" values. These values should be selected such
// that they never occur in your input data.
key_type constexpr empty_key_sentinel = -1;

// Number of keys to be inserted
std::size_t constexpr num_keys = 50'000;

// Compute capacity based on a 50% load factor
auto constexpr load_factor = 0.5;
std::size_t const capacity = std::ceil(num_keys / load_factor);

// Constructs a set with at least `capacity` slots using -1 as the empty keys sentinel.
cuco::static_multiset<key_type> multiset{capacity, cuco::empty_key{empty_key_sentinel}};

// Create a sequence of keys {0, 1, 2, .., i}
// We're going to insert each key twice so we only need 'num_keys / 2' distinct keys.
thrust::device_vector<key_type> keys(num_keys / 2);
thrust::sequence(keys.begin(), keys.end(), 0);

// Inserts all keys into the hash set
multiset.insert(keys.begin(), keys.end());
// Insert the same set of keys again, so each distinct key should occur twice in the multiset
multiset.insert(keys.begin(), keys.end());

// Counts the occurrences of matching keys contained in the multiset.
std::size_t const counted_output_size = multiset.count(keys.begin(), keys.end());

// Storage for result
thrust::device_vector<key_type> output_probes(counted_output_size);
thrust::device_vector<key_type> output_matches(counted_output_size);

// Retrieve all matching keys
auto const [output_probes_end, _] =
multiset.retrieve(keys.begin(), keys.end(), output_probes.begin(), output_matches.begin());
std::size_t const retrieved_output_size = output_probes_end - output_probes.begin();

if ((retrieved_output_size == counted_output_size) and (retrieved_output_size == num_keys)) {
std::cout << "Success! Found all keys.\n";
} else {
std::cout << "Fail! Something went wrong.\n";
}

return 0;
}
2 changes: 1 addition & 1 deletion include/cuco/detail/extent/extent.inl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

#include <cuco/detail/error.hpp>
#include <cuco/detail/prime.hpp> // TODO move to detail/extent/
#include <cuco/detail/utility/math.hpp>
#include <cuco/detail/utility/math.cuh>
#include <cuco/detail/utils.hpp>
#include <cuco/utility/fast_int.cuh>

Expand Down
71 changes: 71 additions & 0 deletions include/cuco/detail/open_addressing/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,77 @@ CUCO_KERNEL __launch_bounds__(BlockSize) void find(InputIt first,
}
}

/**
* @brief Retrieves the equivalent container elements of all keys in the range `[input_probe,
* input_probe + n)`.
*
* If key `k = *(input_probe + i)` has one or more matches in the container, copies `k` to
* `output_probe` and associated slot contents to `output_match`, respectively. The output order is
* unspecified.
*
* @tparam IsOuter Flag indicating whether it's an outer count or not
* @tparam block_size The size of the thread block
* @tparam InputProbeIt Device accessible input iterator
* @tparam OutputProbeIt Device accessible input iterator whose `value_type` is
* convertible to the `InputProbeIt`'s `value_type`
* @tparam OutputMatchIt Device accessible input iterator whose `value_type` is
* convertible to the container's `value_type`
* @tparam AtomicCounter Integral atomic type that follows the same semantics as
* `cuda::(std::)atomic(_ref)`
* @tparam Ref Type of non-owning device ref allowing access to storage
*
* @param input_probe Beginning of the sequence of input keys
* @param n Number of the keys to query
* @param output_probe Beginning of the sequence of keys corresponding to matching elements in
* `output_match`
* @param output_match Beginning of the sequence of matching elements
* @param atomic_counter Pointer to an atomic object of integral type that is used to count the
* number of output elements
* @param ref Non-owning container device ref used to access the slot storage
*/
template <bool IsOuter,
int32_t BlockSize,
class InputProbeIt,
class OutputProbeIt,
class OutputMatchIt,
class AtomicCounter,
class Ref>
CUCO_KERNEL __launch_bounds__(BlockSize) void retrieve(InputProbeIt input_probe,
cuco::detail::index_type n,
OutputProbeIt output_probe,
OutputMatchIt output_match,
AtomicCounter* atomic_counter,
Ref ref)
{
namespace cg = cooperative_groups;

auto const block = cg::this_thread_block();
auto constexpr tiles_in_block = BlockSize / Ref::cg_size;
// make sure all but the last block are always occupied
auto const items_per_block = detail::int_div_ceil(n, tiles_in_block * gridDim.x) * tiles_in_block;

auto const block_begin_offset = block.group_index().x * items_per_block;
auto const block_end_offset = min(n, block_begin_offset + items_per_block);

if (block_begin_offset < block_end_offset) {
if constexpr (IsOuter) {
ref.retrieve_outer<BlockSize>(block,
input_probe + block_begin_offset,
input_probe + block_end_offset,
output_probe,
output_match,
*atomic_counter);
} else {
ref.retrieve<BlockSize>(block,
input_probe + block_begin_offset,
input_probe + block_end_offset,
output_probe,
output_match,
*atomic_counter);
}
}
}

/**
* @brief Inserts all elements in the range `[first, last)`.
*
Expand Down
Loading

0 comments on commit 6816740

Please sign in to comment.