Skip to content

Commit

Permalink
Add retrieve overloads to meet libcudf requirements
Browse files Browse the repository at this point in the history
  • Loading branch information
PointKernel committed Oct 20, 2024
1 parent 6816740 commit 4c1ad2c
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 10 deletions.
40 changes: 37 additions & 3 deletions include/cuco/detail/static_multiset/static_multiset.inl
Original file line number Diff line number Diff line change
Expand Up @@ -304,17 +304,51 @@ template <class Key,
class ProbingScheme,
class Allocator,
class Storage>
template <class InputProbeIt, class OutputProbeIt, class OutputMatchIt>
template <class InputProbeIt,
class ProbeEqual,
class ProbeHash,
class OutputProbeIt,
class OutputMatchIt>
std::pair<OutputProbeIt, OutputMatchIt>
static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::retrieve(
InputProbeIt first,
InputProbeIt last,
ProbeEqual const& probe_equal,
ProbeHash const& probe_hash,
OutputProbeIt output_probe,
OutputMatchIt output_match,
cuda::stream_ref stream) const
{
auto const probe_ref =
this->ref(op::retrieve).rebind_key_eq(probe_equal).rebind_hash_function(probe_hash);
return this->impl_->retrieve(first, last, output_probe, output_match, probe_ref, stream);
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <class InputProbeIt,
class ProbeEqual,
class ProbeHash,
class OutputProbeIt,
class OutputMatchIt>
std::pair<OutputProbeIt, OutputMatchIt>
static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::retrieve_outer(
InputProbeIt first,
InputProbeIt last,
ProbeEqual const& probe_equal,
ProbeHash const& probe_hash,
OutputProbeIt output_probe,
OutputMatchIt output_match,
cuda::stream_ref stream) const
{
return this->impl_->retrieve_outer(
first, last, output_probe, output_match, this->ref(op::retrieve), stream);
auto const probe_ref =
this->ref(op::retrieve).rebind_key_eq(probe_equal).rebind_hash_function(probe_hash);
return this->impl_->retrieve_outer(first, last, output_probe, output_match, probe_ref, stream);
}

template <class Key,
Expand Down
55 changes: 54 additions & 1 deletion include/cuco/static_multiset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,49 @@ class static_multiset {
OutputMatchIt output_match,
cuda::stream_ref stream = {}) const;

/**
* @brief Retrieves all the slots corresponding to all keys in the range `[first, last)`.
*
* If key `k = *(first + i)` exists in the container, copies `k` to `output_probe` and associated
* slot contents to `output_match`, respectively. The output order is unspecified.
*
* Behavior is undefined if the size of the output range exceeds the number of retrieved slots.
* Use `count()` to determine the size of the output range.
*
* This function synchronizes the given CUDA stream.
*
* @tparam InputProbeIt Device accessible input iterator
* @tparam ProbeEqual Binary callable equal type
* @tparam ProbeHash Unary callable hasher type that can be constructed from
* @tparam OutputProbeIt Device accessible input iterator whose `value_type` is
* convertible to the `InputProbeIt`'s `value_type`
* @tparam OutputMatchIt Device accessible input iterator whose `value_type` is
* convertible to the container's `value_type`
*
* @param first Beginning of the input sequence of keys
* @param last End of the input sequence of keys
* @param probe_equal The binary function to compare set keys and probe keys for equality
* @param probe_hash The unary function to hash probe keys
* @param output_probe Beginning of the sequence of keys corresponding to matching elements in
* `output_match`
* @param output_match Beginning of the sequence of matching elements
* @param stream CUDA stream this operation is executed in
*
* @return Iterator pair indicating the the end of the output sequences
*/
template <class InputProbeIt,
class ProbeEqual,
class ProbeHash,
class OutputProbeIt,
class OutputMatchIt>
std::pair<OutputProbeIt, OutputMatchIt> retrieve(InputProbeIt first,
InputProbeIt last,
ProbeEqual const& probe_equal,
ProbeHash const& probe_hash,
OutputProbeIt output_probe,
OutputMatchIt output_match,
cuda::stream_ref stream = {}) const;

/**
* @brief Retrieves all the slots corresponding to all keys in the range `[first, last)`.
*
Expand All @@ -524,23 +567,33 @@ class static_multiset {
* This function synchronizes the given CUDA stream.
*
* @tparam InputProbeIt Device accessible input iterator
* @tparam ProbeEqual Binary callable equal type
* @tparam ProbeHash Unary callable hasher type that can be constructed from
* @tparam OutputProbeIt Device accessible input iterator whose `value_type` is
* convertible to the `InputProbeIt`'s `value_type`
* @tparam OutputMatchIt Device accessible input iterator whose `value_type` is
* convertible to the container's `value_type`
*
* @param first Beginning of the input sequence of keys
* @param last End of the input sequence of keys
* @param probe_equal The binary function to compare set keys and probe keys for equality
* @param probe_hash The unary function to hash probe keys
* @param output_probe Beginning of the sequence of keys corresponding to matching elements in
* `output_match`
* @param output_match Beginning of the sequence of matching elements
* @param stream CUDA stream this operation is executed in
*
* @return Iterator pair indicating the the end of the output sequences
*/
template <class InputProbeIt, class OutputProbeIt, class OutputMatchIt>
template <class InputProbeIt,
class ProbeEqual,
class ProbeHash,
class OutputProbeIt,
class OutputMatchIt>
std::pair<OutputProbeIt, OutputMatchIt> retrieve_outer(InputProbeIt first,
InputProbeIt last,
ProbeEqual const& probe_equal,
ProbeHash const& probe_hash,
OutputProbeIt output_probe,
OutputMatchIt output_match,
cuda::stream_ref stream = {}) const;
Expand Down
2 changes: 1 addition & 1 deletion tests/static_multiset/large_input_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ void test_unique_sequence(Set& set, typename Set::value_type* res_begin, std::si
set.retrieve(keys_begin, keys_end, thrust::make_discard_iterator(), res_begin);
REQUIRE(static_cast<std::size_t>(std::distance(res_begin, res_end)) == num_keys);

thrust::sort(res_begin, res_end);
thrust::sort(thrust::device, res_begin, res_end);

REQUIRE(cuco::test::equal(res_begin, res_end, keys_begin, thrust::equal_to<Key>{}));
}
Expand Down
18 changes: 13 additions & 5 deletions tests/static_multiset/retrieve_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,12 @@ void test_outer(Container& container, std::size_t num_keys)

SECTION("Non-inserted keys should output sentinels.")
{
auto const [probed_end, matched_end] = container.retrieve_outer(
query_keys.begin(), query_keys.end(), probed_keys.begin(), matched_keys.begin());
auto const [probed_end, matched_end] = container.retrieve_outer(query_keys.begin(),
query_keys.end(),
container.key_eq(),
container.hash_function(),
probed_keys.begin(),
matched_keys.begin());
REQUIRE(static_cast<std::size_t>(std::distance(probed_keys.begin(), probed_end)) ==
num_keys * 2ull);
REQUIRE(static_cast<std::size_t>(std::distance(matched_keys.begin(), matched_end)) ==
Expand All @@ -112,8 +116,12 @@ void test_outer(Container& container, std::size_t num_keys)

SECTION("All inserted keys should be contained.")
{
auto const [probed_end, matched_end] = container.retrieve_outer(
query_keys.begin(), query_keys.end(), probed_keys.begin(), matched_keys.begin());
auto const [probed_end, matched_end] = container.retrieve_outer(query_keys.begin(),
query_keys.end(),
container.key_eq(),
container.hash_function(),
probed_keys.begin(),
matched_keys.begin());
thrust::sort_by_key(
probed_keys.begin(), probed_end, matched_keys.begin(), thrust::less<key_type>());

Expand Down Expand Up @@ -160,4 +168,4 @@ TEMPLATE_TEST_CASE_SIG(
test_multiplicity(set, num_keys, 2); // each key occurs twice
test_multiplicity(set, num_keys, 11);
test_outer(set, num_keys);
}
}

0 comments on commit 4c1ad2c

Please sign in to comment.