diff --git a/cpp/src/lists/contains.cu b/cpp/src/lists/contains.cu index a3293e36825..9d39f2f9a90 100644 --- a/cpp/src/lists/contains.cu +++ b/cpp/src/lists/contains.cu @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -154,16 +155,11 @@ struct search_list_nested_types_fn { duplicate_find_option const find_option; KeyValidityIter const key_validity_iter; EqComparator const d_comp; - bool const search_key_is_scalar; search_list_nested_types_fn(duplicate_find_option const find_option, KeyValidityIter const key_validity_iter, - EqComparator const& d_comp, - bool search_key_is_scalar) - : find_option(find_option), - key_validity_iter(key_validity_iter), - d_comp(d_comp), - search_key_is_scalar(search_key_is_scalar) + EqComparator const& d_comp) + : find_option(find_option), key_validity_iter(key_validity_iter), d_comp(d_comp) { } @@ -186,9 +182,8 @@ struct search_list_nested_types_fn { auto const [begin, end] = element_index_pair_iter(list.size()); auto const found_iter = thrust::find_if(thrust::seq, begin, end, [=] __device__(auto const idx) { - return !list.is_null(idx) && - d_comp(static_cast(list.element_offset(idx)), - static_cast(search_key_is_scalar ? 0 : list.row_index())); + return !list.is_null(idx) && d_comp(static_cast(list.element_offset(idx)), + static_cast(list.row_index())); }); // If the key is found, return its found position in the list from `found_iter`. return found_iter == end ? NOT_FOUND_SENTINEL : *found_iter; @@ -199,93 +194,53 @@ struct search_list_nested_types_fn { * @brief Function to search for key element(s) in the corresponding rows of a lists column, * specialized for non-nested types. */ -template +template void index_of_non_nested_types(InputIterator input_it, size_type num_rows, OutputIterator output_it, - SearchKeyType const& search_keys, + column_view const& search_keys, bool search_keys_have_nulls, duplicate_find_option find_option, rmm::cuda_stream_view stream) { - auto const do_search = [=](auto const keys_iter) { - thrust::transform(rmm::exec_policy(stream), - input_it, - input_it + num_rows, - keys_iter, - output_it, - search_list_non_nested_types_fn{find_option}); - }; - - if constexpr (search_key_is_scalar) { - auto const keys_iter = cudf::detail::make_optional_iterator( - search_keys, nullate::DYNAMIC{search_keys_have_nulls}); - do_search(keys_iter); - } else { - auto const keys_cdv_ptr = column_device_view::create(search_keys, stream); - auto const keys_iter = cudf::detail::make_optional_iterator( - *keys_cdv_ptr, nullate::DYNAMIC{search_keys_have_nulls}); - do_search(keys_iter); - } + auto const keys_cdv_ptr = column_device_view::create(search_keys, stream); + auto const keys_iter = cudf::detail::make_optional_iterator( + *keys_cdv_ptr, nullate::DYNAMIC{search_keys_have_nulls}); + thrust::transform(rmm::exec_policy(stream), + input_it, + input_it + num_rows, + keys_iter, + output_it, + search_list_non_nested_types_fn{find_option}); } /** * @brief Function to search for index of key element(s) in the corresponding rows of a lists * column, specialized for nested types. */ -template +template void index_of_nested_types(InputIterator input_it, size_type num_rows, OutputIterator output_it, column_view const& child, - SearchKeyType const& search_keys, + column_view const& search_keys, duplicate_find_option find_option, rmm::cuda_stream_view stream) { - // Create a `table_view` from the search key(s). - // If the input search key is a (nested type) scalar, a new column is materialized from that - // scalar before a `table_view` is generated from it. As such, the new created column will also be - // returned to keep the result `table_view` valid. - [[maybe_unused]] auto const [keys_tview, unused_column] = - [&]() -> std::pair> { - if constexpr (search_key_is_scalar) { - auto tmp_column = make_column_from_scalar(search_keys, 1, stream); - return {table_view{{tmp_column->view()}}, std::move(tmp_column)}; - } else { - return {table_view{{search_keys}}, nullptr}; - } - }(); - + auto const keys_tview = cudf::table_view{{search_keys}}; auto const child_tview = table_view{{child}}; auto const has_nulls = has_nested_nulls(child_tview) || has_nested_nulls(keys_tview); auto const comparator = cudf::experimental::row::equality::two_table_comparator(child_tview, keys_tview, stream); auto const d_comp = comparator.equal_to(nullate::DYNAMIC{has_nulls}); - auto const do_search = [=](auto const key_validity_iter) { - thrust::transform( - rmm::exec_policy(stream), - input_it, - input_it + num_rows, - output_it, - search_list_nested_types_fn{find_option, key_validity_iter, d_comp, search_key_is_scalar}); - }; - - if constexpr (search_key_is_scalar) { - auto const key_validity_iter = cudf::detail::make_validity_iterator(search_keys); - do_search(key_validity_iter); - } else { - auto const keys_dv_ptr = column_device_view::create(search_keys, stream); - auto const key_validity_iter = cudf::detail::make_validity_iterator(*keys_dv_ptr); - do_search(key_validity_iter); - } + auto const keys_dv_ptr = column_device_view::create(search_keys, stream); + auto const key_validity_iter = cudf::detail::make_validity_iterator(*keys_dv_ptr); + thrust::transform(rmm::exec_policy(stream), + input_it, + input_it + num_rows, + output_it, + search_list_nested_types_fn{find_option, key_validity_iter, d_comp}); } /** @@ -295,10 +250,10 @@ void index_of_nested_types(InputIterator input_it, struct dispatch_index_of { // SFINAE with conditional return type because we need to support device lambda in this function. // This is required due to a limitation of nvcc. - template + template std::enable_if_t(), std::unique_ptr> operator()( lists_column_view const& lists, - SearchKeyType const& search_keys, + column_view const& search_keys, duplicate_find_option find_option, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) const @@ -313,27 +268,10 @@ struct dispatch_index_of { cudf::data_type_error); CUDF_EXPECTS(search_keys.type().id() != type_id::EMPTY, "Type cannot be empty."); - auto constexpr search_key_is_scalar = std::is_same_v; - auto const search_keys_have_nulls = [&search_keys, stream] { - if constexpr (search_key_is_scalar) { - return !search_keys.is_valid(stream); - } else { - return search_keys.has_nulls(); - } - }(); + auto const search_keys_have_nulls = search_keys.has_nulls(); auto const num_rows = lists.size(); - if (search_key_is_scalar && search_keys_have_nulls) { - // If the scalar key is invalid/null, the entire output column will be all nulls. - return make_numeric_column(data_type{cudf::type_to_id()}, - num_rows, - cudf::create_null_mask(num_rows, mask_state::ALL_NULL, mr), - num_rows, - stream, - mr); - } - auto const lists_cdv_ptr = column_device_view::create(lists.parent(), stream); auto const input_it = cudf::detail::make_counting_transform_iterator( size_type{0}, @@ -346,11 +284,10 @@ struct dispatch_index_of { auto const output_it = out_positions->mutable_view().template begin(); if constexpr (not cudf::is_nested()) { - index_of_non_nested_types( + index_of_non_nested_types( input_it, num_rows, output_it, search_keys, search_keys_have_nulls, find_option, stream); } else { // list + struct - index_of_nested_types( - input_it, num_rows, output_it, child, search_keys, find_option, stream); + index_of_nested_types(input_it, num_rows, output_it, child, search_keys, find_option, stream); } if (search_keys_have_nulls || lists.has_nulls()) { @@ -414,8 +351,21 @@ std::unique_ptr index_of(lists_column_view const& lists, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { - return cudf::type_dispatcher( - search_key.type(), dispatch_index_of{}, lists, search_key, find_option, stream, mr); + if (!search_key.is_valid(stream)) { + return make_numeric_column(data_type{cudf::type_to_id()}, + lists.size(), + cudf::create_null_mask(lists.size(), mask_state::ALL_NULL, mr), + lists.size(), + stream, + mr); + } + if (lists.size() == 0) { + return make_numeric_column( + data_type{type_to_id()}, 0, cudf::mask_state::UNALLOCATED, stream, mr); + } + + auto search_key_col = cudf::make_column_from_scalar(search_key, lists.size(), stream, mr); + return index_of(lists, search_key_col->view(), find_option, stream, mr); } std::unique_ptr index_of(lists_column_view const& lists,