Skip to content

Commit

Permalink
experimental comparators only
Browse files Browse the repository at this point in the history
  • Loading branch information
divyegala committed Aug 2, 2023
1 parent c61d2af commit 7bc277b
Showing 1 changed file with 37 additions and 95 deletions.
132 changes: 37 additions & 95 deletions cpp/src/lists/contains.cu
Original file line number Diff line number Diff line change
Expand Up @@ -105,60 +105,19 @@ __device__ auto element_index_pair_iter(size_type const size)
}
}

/**
 * @brief Functor that looks up the position of a key element inside one list row, for
 * element types that are not nested.
 *
 * The `find_option` member selects which search direction is requested from
 * `element_index_pair_iter` (`FIND_FIRST` uses the `forward == true` instantiation,
 * anything else uses `forward == false`).
 */
struct search_list_non_nested_types_fn {
  duplicate_find_option const find_option;

  /**
   * @brief Search `list` for the key held in `key_opt`.
   *
   * @param list    The list row to search in.
   * @param key_opt The (possibly absent) key to look for.
   * @return `NULL_SENTINEL` when the list is null or the key is absent; otherwise the
   *         result of the directional search (an element index or `NOT_FOUND_SENTINEL`).
   */
  template <typename Element, CUDF_ENABLE_IF(is_supported_non_nested_type<Element>())>
  __device__ size_type operator()(list_device_view const list,
                                  thrust::optional<Element> const key_opt) const
  {
    // Either a null list or a null key produces a null output row.
    if (list.is_null()) { return NULL_SENTINEL; }
    if (!key_opt) { return NULL_SENTINEL; }

    if (find_option == duplicate_find_option::FIND_FIRST) {
      return search_list<true, Element>(list, *key_opt);
    }
    return search_list<false, Element>(list, *key_opt);
  }

  /**
   * @brief Overload selected for element types that are not supported; must never execute.
   */
  template <typename Element, CUDF_ENABLE_IF(!is_supported_non_nested_type<Element>())>
  __device__ size_type operator()(list_device_view const, thrust::optional<Element> const) const
  {
    CUDF_UNREACHABLE("Unsupported type.");
  }

 private:
  /**
   * @brief Sequentially walk the index range given by `element_index_pair_iter<forward>` and
   * report the first visited index whose element is non-null and equal to `search_key`.
   *
   * @return The matching element index, or `NOT_FOUND_SENTINEL` if no element matches.
   */
  template <bool forward, typename Element, CUDF_ENABLE_IF(is_supported_non_nested_type<Element>())>
  static __device__ inline size_type search_list(list_device_view const list,
                                                 Element const search_key)
  {
    auto const [begin, end] = element_index_pair_iter<forward>(list.size());

    // Null elements never match; valid elements are compared against the key by value.
    auto const matches_key = [=] __device__(auto const idx) {
      return !list.is_null(idx) &&
             cudf::equality_compare(list.template element<Element>(idx), search_key);
    };

    auto const hit = thrust::find_if(thrust::seq, begin, end, matches_key);

    // Reaching `end` means nothing matched; otherwise the iterator dereferences to the index.
    if (hit == end) { return NOT_FOUND_SENTINEL; }
    return *hit;
  }
};

/**
* @brief Functor to perform searching for index of a key element in a given list, specialized
* for nested types.
*/
template <typename KeyValidityIter, typename EqComparator>
struct search_list_nested_types_fn {
struct search_list {
duplicate_find_option const find_option;
KeyValidityIter const key_validity_iter;
EqComparator const d_comp;

search_list_nested_types_fn(duplicate_find_option const find_option,
KeyValidityIter const key_validity_iter,
EqComparator const& d_comp)
search_list(duplicate_find_option const find_option,
KeyValidityIter const key_validity_iter,
EqComparator const& d_comp)
: find_option(find_option), key_validity_iter(key_validity_iter), d_comp(d_comp)
{
}
Expand All @@ -168,13 +127,13 @@ struct search_list_nested_types_fn {
// A null list or null key will result in a null output row.
if (list.is_null() || !key_validity_iter[list.row_index()]) { return NULL_SENTINEL; }

return find_option == duplicate_find_option::FIND_FIRST ? search_list<true>(list)
: search_list<false>(list);
return find_option == duplicate_find_option::FIND_FIRST ? search_list_op<true>(list)
: search_list_op<false>(list);
}

private:
template <bool forward>
__device__ inline size_type search_list(list_device_view const list) const
__device__ inline size_type search_list_op(list_device_view const list) const
{
using cudf::experimental::row::lhs_index_type;
using cudf::experimental::row::rhs_index_type;
Expand All @@ -190,57 +149,45 @@ struct search_list_nested_types_fn {
}
};

/**
 * @brief Search each row of a lists column for its corresponding key element, for
 * non-nested element types.
 *
 * Builds a device view over `search_keys`, pairs every input list with its key through an
 * optional iterator, and writes one result per row to `output_it` using
 * `search_list_non_nested_types_fn`.
 *
 * @param input_it               Iterator to the first input list row.
 * @param num_rows               Number of rows to process.
 * @param output_it              Iterator receiving one `size_type` result per row.
 * @param search_keys            Column holding the key for each row.
 * @param search_keys_have_nulls Whether `search_keys` may contain nulls.
 * @param find_option            Which occurrence to report when a key appears more than once.
 * @param stream                 CUDA stream on which the work is issued.
 */
template <typename Element, typename InputIterator, typename OutputIterator>
void index_of_non_nested_types(InputIterator input_it,
                               size_type num_rows,
                               OutputIterator output_it,
                               column_view const& search_keys,
                               bool search_keys_have_nulls,
                               duplicate_find_option find_option,
                               rmm::cuda_stream_view stream)
{
  // The device view must stay alive for the duration of the transform below.
  auto const d_keys_owner = column_device_view::create(search_keys, stream);
  auto const nullability  = nullate::DYNAMIC{search_keys_have_nulls};
  auto const key_begin =
    cudf::detail::make_optional_iterator<Element>(*d_keys_owner, nullability);

  auto const input_end = input_it + num_rows;
  auto const searcher  = search_list_non_nested_types_fn{find_option};

  thrust::transform(
    rmm::exec_policy(stream), input_it, input_end, key_begin, output_it, searcher);
}

/**
* @brief Function to search for index of key element(s) in the corresponding rows of a lists
* column, specialized for nested types.
*/
template <typename InputIterator, typename OutputIterator>
void index_of_nested_types(InputIterator input_it,
size_type num_rows,
OutputIterator output_it,
column_view const& child,
column_view const& search_keys,
duplicate_find_option find_option,
rmm::cuda_stream_view stream)
void index_of(InputIterator input_it,
size_type num_rows,
OutputIterator output_it,
column_view const& child,
column_view const& search_keys,
duplicate_find_option find_option,
rmm::cuda_stream_view stream)
{
auto const keys_tview = cudf::table_view{{search_keys}};
auto const child_tview = table_view{{child}};
auto const has_nulls = has_nested_nulls(child_tview) || has_nested_nulls(keys_tview);
auto const comparator =
cudf::experimental::row::equality::two_table_comparator(child_tview, keys_tview, stream);
auto const d_comp = comparator.equal_to<true>(nullate::DYNAMIC{has_nulls});

auto const keys_dv_ptr = column_device_view::create(search_keys, stream);
auto const key_validity_iter = cudf::detail::make_validity_iterator<true>(*keys_dv_ptr);
thrust::transform(rmm::exec_policy(stream),
input_it,
input_it + num_rows,
output_it,
search_list_nested_types_fn{find_option, key_validity_iter, d_comp});
auto do_search = [&](auto const d_comp) {
auto const keys_dv_ptr = column_device_view::create(search_keys, stream);
auto const key_validity_iter = cudf::detail::make_validity_iterator<true>(*keys_dv_ptr);
thrust::transform(rmm::exec_policy(stream),
input_it,
input_it + num_rows,
output_it,
search_list{find_option, key_validity_iter, d_comp});
};

if (cudf::detail::has_nested_columns(child_tview) or
cudf::detail::has_nested_columns(keys_tview)) {
auto const comparator =
cudf::experimental::row::equality::two_table_comparator(child_tview, keys_tview, stream);
auto const d_comp = comparator.equal_to<true>(nullate::DYNAMIC{has_nulls});
do_search(d_comp);
} else {
auto const comparator =
cudf::experimental::row::equality::two_table_comparator(child_tview, keys_tview, stream);
auto const d_comp = comparator.equal_to<false>(nullate::DYNAMIC{has_nulls});
do_search(d_comp);
}
}

/**
Expand Down Expand Up @@ -283,12 +230,7 @@ struct dispatch_index_of {
data_type{type_to_id<size_type>()}, num_rows, cudf::mask_state::UNALLOCATED, stream, mr);
auto const output_it = out_positions->mutable_view().template begin<size_type>();

if constexpr (not cudf::is_nested<Element>()) {
index_of_non_nested_types<Element>(
input_it, num_rows, output_it, search_keys, search_keys_have_nulls, find_option, stream);
} else { // list + struct
index_of_nested_types(input_it, num_rows, output_it, child, search_keys, find_option, stream);
}
index_of(input_it, num_rows, output_it, child, search_keys, find_option, stream);

if (search_keys_have_nulls || lists.has_nulls()) {
auto [null_mask, null_count] = cudf::detail::valid_if(
Expand Down

0 comments on commit 7bc277b

Please sign in to comment.