diff --git a/include/cuco/detail/static_multiset/static_multiset.inl b/include/cuco/detail/static_multiset/static_multiset.inl index b88f0f6ff..76127d583 100644 --- a/include/cuco/detail/static_multiset/static_multiset.inl +++ b/include/cuco/detail/static_multiset/static_multiset.inl @@ -304,17 +304,51 @@ template -template +template +std::pair +static_multiset::retrieve( + InputProbeIt first, + InputProbeIt last, + ProbeEqual const& probe_equal, + ProbeHash const& probe_hash, + OutputProbeIt output_probe, + OutputMatchIt output_match, + cuda::stream_ref stream) const +{ + auto const probe_ref = + this->ref(op::retrieve).rebind_key_eq(probe_equal).rebind_hash_function(probe_hash); + return this->impl_->retrieve(first, last, output_probe, output_match, probe_ref, stream); +} + +template +template std::pair static_multiset::retrieve_outer( InputProbeIt first, InputProbeIt last, + ProbeEqual const& probe_equal, + ProbeHash const& probe_hash, OutputProbeIt output_probe, OutputMatchIt output_match, cuda::stream_ref stream) const { - return this->impl_->retrieve_outer( - first, last, output_probe, output_match, this->ref(op::retrieve), stream); + auto const probe_ref = + this->ref(op::retrieve).rebind_key_eq(probe_equal).rebind_hash_function(probe_hash); + return this->impl_->retrieve_outer(first, last, output_probe, output_match, probe_ref, stream); } template + std::pair retrieve(InputProbeIt first, + InputProbeIt last, + ProbeEqual const& probe_equal, + ProbeHash const& probe_hash, + OutputProbeIt output_probe, + OutputMatchIt output_match, + cuda::stream_ref stream = {}) const; + /** * @brief Retrieves all the slots corresponding to all keys in the range `[first, last)`. * @@ -524,6 +567,8 @@ class static_multiset { * This function synchronizes the given CUDA stream. 
* * @tparam InputProbeIt Device accessible input iterator + * @tparam ProbeEqual Binary callable equal type + * @tparam ProbeHash Unary callable hasher type used to hash probe keys * @tparam OutputProbeIt Device accessible input iterator whose `value_type` is * convertible to the `InputProbeIt`'s `value_type` * @tparam OutputMatchIt Device accessible input iterator whose `value_type` is @@ -531,6 +576,8 @@ class static_multiset { * * @param first Beginning of the input sequence of keys * @param last End of the input sequence of keys + * @param probe_equal The binary function to compare set keys and probe keys for equality + * @param probe_hash The unary function to hash probe keys * @param output_probe Beginning of the sequence of keys corresponding to matching elements in * `output_match` * @param output_match Beginning of the sequence of matching elements @@ -538,9 +585,15 @@ class static_multiset { * * @return Iterator pair indicating the the end of the output sequences */ - template + template std::pair retrieve_outer(InputProbeIt first, InputProbeIt last, + ProbeEqual const& probe_equal, + ProbeHash const& probe_hash, OutputProbeIt output_probe, OutputMatchIt output_match, cuda::stream_ref stream = {}) const; diff --git a/tests/static_multiset/large_input_test.cu b/tests/static_multiset/large_input_test.cu index 015260676..2896a4a7f 100644 --- a/tests/static_multiset/large_input_test.cu +++ b/tests/static_multiset/large_input_test.cu @@ -46,7 +46,7 @@ void test_unique_sequence(Set& set, typename Set::value_type* res_begin, std::si set.retrieve(keys_begin, keys_end, thrust::make_discard_iterator(), res_begin); REQUIRE(static_cast(std::distance(res_begin, res_end)) == num_keys); - thrust::sort(res_begin, res_end); + thrust::sort(thrust::device, res_begin, res_end); REQUIRE(cuco::test::equal(res_begin, res_end, keys_begin, thrust::equal_to{})); } diff --git a/tests/static_multiset/retrieve_test.cu b/tests/static_multiset/retrieve_test.cu index 
300c8dc6c..ad21333ba 100644 --- a/tests/static_multiset/retrieve_test.cu +++ b/tests/static_multiset/retrieve_test.cu @@ -94,8 +94,12 @@ void test_outer(Container& container, std::size_t num_keys) SECTION("Non-inserted keys should output sentinels.") { - auto const [probed_end, matched_end] = container.retrieve_outer( - query_keys.begin(), query_keys.end(), probed_keys.begin(), matched_keys.begin()); + auto const [probed_end, matched_end] = container.retrieve_outer(query_keys.begin(), + query_keys.end(), + container.key_eq(), + container.hash_function(), + probed_keys.begin(), + matched_keys.begin()); REQUIRE(static_cast(std::distance(probed_keys.begin(), probed_end)) == num_keys * 2ull); REQUIRE(static_cast(std::distance(matched_keys.begin(), matched_end)) == @@ -112,8 +116,12 @@ void test_outer(Container& container, std::size_t num_keys) SECTION("All inserted keys should be contained.") { - auto const [probed_end, matched_end] = container.retrieve_outer( - query_keys.begin(), query_keys.end(), probed_keys.begin(), matched_keys.begin()); + auto const [probed_end, matched_end] = container.retrieve_outer(query_keys.begin(), + query_keys.end(), + container.key_eq(), + container.hash_function(), + probed_keys.begin(), + matched_keys.begin()); thrust::sort_by_key( probed_keys.begin(), probed_end, matched_keys.begin(), thrust::less()); @@ -160,4 +168,4 @@ TEMPLATE_TEST_CASE_SIG( test_multiplicity(set, num_keys, 2); // each key occurs twice test_multiplicity(set, num_keys, 11); test_outer(set, num_keys); -} \ No newline at end of file +}