Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce lists::contains dispatches for scalars #13805

Merged
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 42 additions & 83 deletions cpp/src/lists/contains.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <cudf/detail/iterator.cuh>
#include <cudf/detail/valid_if.cuh>
#include <cudf/lists/detail/contains.hpp>
#include <cudf/lists/detail/lists_column_factories.hpp>
#include <cudf/lists/list_device_view.cuh>
#include <cudf/lists/lists_column_device_view.cuh>
#include <cudf/lists/lists_column_view.hpp>
Expand Down Expand Up @@ -154,16 +155,11 @@ struct search_list_nested_types_fn {
duplicate_find_option const find_option;
KeyValidityIter const key_validity_iter;
EqComparator const d_comp;
bool const search_key_is_scalar;

search_list_nested_types_fn(duplicate_find_option const find_option,
KeyValidityIter const key_validity_iter,
EqComparator const& d_comp,
bool search_key_is_scalar)
: find_option(find_option),
key_validity_iter(key_validity_iter),
d_comp(d_comp),
search_key_is_scalar(search_key_is_scalar)
EqComparator const& d_comp)
: find_option(find_option), key_validity_iter(key_validity_iter), d_comp(d_comp)
{
}

Expand All @@ -186,9 +182,8 @@ struct search_list_nested_types_fn {
auto const [begin, end] = element_index_pair_iter<forward>(list.size());
auto const found_iter =
thrust::find_if(thrust::seq, begin, end, [=] __device__(auto const idx) {
return !list.is_null(idx) &&
d_comp(static_cast<lhs_index_type>(list.element_offset(idx)),
static_cast<rhs_index_type>(search_key_is_scalar ? 0 : list.row_index()));
return !list.is_null(idx) && d_comp(static_cast<lhs_index_type>(list.element_offset(idx)),
static_cast<rhs_index_type>(list.row_index()));
});
// If the key is found, return its found position in the list from `found_iter`.
return found_iter == end ? NOT_FOUND_SENTINEL : *found_iter;
Expand All @@ -199,15 +194,11 @@ struct search_list_nested_types_fn {
* @brief Function to search for key element(s) in the corresponding rows of a lists column,
* specialized for non-nested types.
*/
template <bool search_key_is_scalar,
typename Element,
typename InputIterator,
typename OutputIterator,
typename SearchKeyType>
template <typename Element, typename InputIterator, typename OutputIterator>
void index_of_non_nested_types(InputIterator input_it,
size_type num_rows,
OutputIterator output_it,
SearchKeyType const& search_keys,
column_view const& search_keys,
bool search_keys_have_nulls,
duplicate_find_option find_option,
rmm::cuda_stream_view stream)
Expand All @@ -221,47 +212,26 @@ void index_of_non_nested_types(InputIterator input_it,
search_list_non_nested_types_fn{find_option});
};

if constexpr (search_key_is_scalar) {
auto const keys_iter = cudf::detail::make_optional_iterator<Element>(
search_keys, nullate::DYNAMIC{search_keys_have_nulls});
do_search(keys_iter);
} else {
auto const keys_cdv_ptr = column_device_view::create(search_keys, stream);
auto const keys_iter = cudf::detail::make_optional_iterator<Element>(
*keys_cdv_ptr, nullate::DYNAMIC{search_keys_have_nulls});
do_search(keys_iter);
}
auto const keys_cdv_ptr = column_device_view::create(search_keys, stream);
auto const keys_iter = cudf::detail::make_optional_iterator<Element>(
*keys_cdv_ptr, nullate::DYNAMIC{search_keys_have_nulls});
do_search(keys_iter);
divyegala marked this conversation as resolved.
Show resolved Hide resolved
}

/**
* @brief Function to search for index of key element(s) in the corresponding rows of a lists
* column, specialized for nested types.
*/
template <bool search_key_is_scalar,
typename InputIterator,
typename OutputIterator,
typename SearchKeyType>
template <typename InputIterator, typename OutputIterator>
void index_of_nested_types(InputIterator input_it,
size_type num_rows,
OutputIterator output_it,
column_view const& child,
SearchKeyType const& search_keys,
column_view const& search_keys,
duplicate_find_option find_option,
rmm::cuda_stream_view stream)
{
// Create a `table_view` from the search key(s).
// If the input search key is a (nested type) scalar, a new column is materialized from that
// scalar before a `table_view` is generated from it. The newly created column is also
// returned to keep the resulting `table_view` valid.
[[maybe_unused]] auto const [keys_tview, unused_column] =
[&]() -> std::pair<table_view, std::unique_ptr<column>> {
if constexpr (search_key_is_scalar) {
auto tmp_column = make_column_from_scalar(search_keys, 1, stream);
return {table_view{{tmp_column->view()}}, std::move(tmp_column)};
} else {
return {table_view{{search_keys}}, nullptr};
}
}();
auto const keys_tview = cudf::table_view{{search_keys}};

divyegala marked this conversation as resolved.
Show resolved Hide resolved
auto const child_tview = table_view{{child}};
auto const has_nulls = has_nested_nulls(child_tview) || has_nested_nulls(keys_tview);
Expand All @@ -270,22 +240,16 @@ void index_of_nested_types(InputIterator input_it,
auto const d_comp = comparator.equal_to<true>(nullate::DYNAMIC{has_nulls});

auto const do_search = [=](auto const key_validity_iter) {
thrust::transform(
rmm::exec_policy(stream),
input_it,
input_it + num_rows,
output_it,
search_list_nested_types_fn{find_option, key_validity_iter, d_comp, search_key_is_scalar});
thrust::transform(rmm::exec_policy(stream),
input_it,
input_it + num_rows,
output_it,
search_list_nested_types_fn{find_option, key_validity_iter, d_comp});
};

if constexpr (search_key_is_scalar) {
auto const key_validity_iter = cudf::detail::make_validity_iterator<true>(search_keys);
do_search(key_validity_iter);
} else {
auto const keys_dv_ptr = column_device_view::create(search_keys, stream);
auto const key_validity_iter = cudf::detail::make_validity_iterator<true>(*keys_dv_ptr);
do_search(key_validity_iter);
}
auto const keys_dv_ptr = column_device_view::create(search_keys, stream);
auto const key_validity_iter = cudf::detail::make_validity_iterator<true>(*keys_dv_ptr);
do_search(key_validity_iter);
divyegala marked this conversation as resolved.
Show resolved Hide resolved
}

/**
Expand All @@ -295,10 +259,10 @@ void index_of_nested_types(InputIterator input_it,
struct dispatch_index_of {
// SFINAE with conditional return type because we need to support device lambda in this function.
// This is required due to a limitation of nvcc.
template <typename Element, typename SearchKeyType>
template <typename Element>
std::enable_if_t<is_supported_type<Element>(), std::unique_ptr<column>> operator()(
lists_column_view const& lists,
SearchKeyType const& search_keys,
column_view const& search_keys,
duplicate_find_option find_option,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr) const
Expand All @@ -313,27 +277,10 @@ struct dispatch_index_of {
cudf::data_type_error);
CUDF_EXPECTS(search_keys.type().id() != type_id::EMPTY, "Type cannot be empty.");

auto constexpr search_key_is_scalar = std::is_same_v<SearchKeyType, cudf::scalar>;
auto const search_keys_have_nulls = [&search_keys, stream] {
if constexpr (search_key_is_scalar) {
return !search_keys.is_valid(stream);
} else {
return search_keys.has_nulls();
}
}();
auto const search_keys_have_nulls = search_keys.has_nulls();

auto const num_rows = lists.size();

if (search_key_is_scalar && search_keys_have_nulls) {
// If the scalar key is invalid/null, the entire output column will be all nulls.
return make_numeric_column(data_type{cudf::type_to_id<size_type>()},
num_rows,
cudf::create_null_mask(num_rows, mask_state::ALL_NULL, mr),
num_rows,
stream,
mr);
}

auto const lists_cdv_ptr = column_device_view::create(lists.parent(), stream);
auto const input_it = cudf::detail::make_counting_transform_iterator(
size_type{0},
Expand All @@ -346,11 +293,10 @@ struct dispatch_index_of {
auto const output_it = out_positions->mutable_view().template begin<size_type>();

if constexpr (not cudf::is_nested<Element>()) {
index_of_non_nested_types<search_key_is_scalar, Element>(
index_of_non_nested_types<Element>(
input_it, num_rows, output_it, search_keys, search_keys_have_nulls, find_option, stream);
} else { // list + struct
index_of_nested_types<search_key_is_scalar>(
input_it, num_rows, output_it, child, search_keys, find_option, stream);
index_of_nested_types(input_it, num_rows, output_it, child, search_keys, find_option, stream);
}

if (search_keys_have_nulls || lists.has_nulls()) {
Expand Down Expand Up @@ -414,8 +360,21 @@ std::unique_ptr<column> index_of(lists_column_view const& lists,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
return cudf::type_dispatcher(
search_key.type(), dispatch_index_of{}, lists, search_key, find_option, stream, mr);
if (!search_key.is_valid(stream)) {
return make_numeric_column(data_type{cudf::type_to_id<size_type>()},
lists.size(),
cudf::create_null_mask(lists.size(), mask_state::ALL_NULL, mr),
lists.size(),
stream,
mr);
}
if (lists.size() == 0) {
return make_numeric_column(
data_type{type_to_id<size_type>()}, 0, cudf::mask_state::UNALLOCATED, stream, mr);
}

auto search_key_col = cudf::make_column_from_scalar(search_key, lists.size(), stream, mr);
return index_of(lists, search_key_col->view(), find_option, stream, mr);
}

std::unique_ptr<column> index_of(lists_column_view const& lists,
Expand Down
Loading