Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-44393: [C++][Compute] Swizzle vector functions #44394

Open
wants to merge 58 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
1e141d7
WIP
zanmato1984 Sep 29, 2024
216e217
WIP
zanmato1984 Sep 30, 2024
f3c73ea
Add permute function options
zanmato1984 Oct 2, 2024
be88f0c
WIP
zanmato1984 Oct 4, 2024
b445c36
Implementation done and basic tests
zanmato1984 Oct 6, 2024
3707fa6
Implement permute
zanmato1984 Oct 7, 2024
4e9d3a6
Reorg reverse_index
zanmato1984 Oct 10, 2024
c4c5c41
Fix API and doc
zanmato1984 Oct 10, 2024
78bb335
Fix API and doc
zanmato1984 Oct 10, 2024
38bcc5d
Merge remote-tracking branch 'origin/main' into vector-placement
zanmato1984 Oct 10, 2024
d88877a
Init docs
zanmato1984 Oct 10, 2024
cc6a0ef
Merge branch 'vector-permute' into vector-placement
zanmato1984 Oct 10, 2024
2bbf44b
Refine
zanmato1984 Oct 10, 2024
b31f9f2
Update docs
zanmato1984 Oct 10, 2024
b951348
Refine doc
zanmato1984 Oct 11, 2024
520b952
Add comments for the implementation
zanmato1984 Oct 11, 2024
b450f5e
Refine docs
zanmato1984 Oct 11, 2024
4ea1465
Fix uint64 overflow check
zanmato1984 Oct 11, 2024
cbdce2f
Reverse indices tests
zanmato1984 Oct 11, 2024
d2e118a
Forbit non-array-like argument
zanmato1984 Oct 11, 2024
7128a28
Fix permute option default
zanmato1984 Oct 11, 2024
034d3b7
Refine
zanmato1984 Oct 11, 2024
9f93e5c
WIP permute tests
zanmato1984 Oct 11, 2024
c320002
Refine tests
zanmato1984 Oct 12, 2024
3e438e8
More permute tests
zanmato1984 Oct 12, 2024
0811b2b
Add if-else tests using permute
zanmato1984 Oct 13, 2024
154ad95
Update some comments
zanmato1984 Oct 13, 2024
66d977a
Fix lint
zanmato1984 Oct 14, 2024
a4c292c
Merge remote-tracking branch 'origin/main' into vector-placement
zanmato1984 Oct 14, 2024
846039d
Update comment
zanmato1984 Oct 14, 2024
2f2ae47
Fix typo
zanmato1984 Oct 14, 2024
3af49a8
Typo
zanmato1984 Oct 14, 2024
944609c
Refine
zanmato1984 Oct 17, 2024
e132f0d
Update cpp/src/arrow/compute/kernels/vector_placement_test.cc
zanmato1984 Oct 31, 2024
220598b
Rename function category to swizzle
zanmato1984 Nov 4, 2024
c03f6e0
reverse_indices -> inverse_permutation
zanmato1984 Nov 4, 2024
705c7b2
output_length -> max_index
zanmato1984 Nov 4, 2024
9e9ccb0
Permute -> Scatter
zanmato1984 Nov 4, 2024
bd334fe
Fixing some renamings
zanmato1984 Nov 6, 2024
bbe328d
Update docs/source/cpp/compute.rst
zanmato1984 Dec 11, 2024
5680322
Merge main
zanmato1984 Dec 11, 2024
6cd7f40
Update cpp/src/arrow/compute/api_vector.h
zanmato1984 Dec 11, 2024
9364e4c
Update cpp/src/arrow/compute/api_vector.h
zanmato1984 Dec 11, 2024
18f1f32
Update cpp/src/arrow/compute/kernels/vector_swizzle_test.cc
zanmato1984 Dec 11, 2024
d27782b
Update cpp/src/arrow/compute/kernels/vector_swizzle_test.cc
zanmato1984 Dec 11, 2024
2d09a8e
Limit input/output type to signed integers
zanmato1984 Dec 11, 2024
b0c52e8
Make visit method public and remove friend
zanmato1984 Dec 11, 2024
9315ce5
Show no mercy to index out of bounds
zanmato1984 Dec 12, 2024
17b6b2b
Use type error instead of invalid
zanmato1984 Dec 12, 2024
57d2c3d
Remove errornous predict false
zanmato1984 Dec 12, 2024
55d4611
Avoid uninitialized data buf
zanmato1984 Dec 12, 2024
e8fed69
Coding convention of instantce variables
zanmato1984 Dec 12, 2024
3fbd35f
Optimize buffer initializing
zanmato1984 Dec 12, 2024
3859abf
Reduce typed tests
zanmato1984 Dec 12, 2024
d17b88b
Naming
zanmato1984 Dec 12, 2024
cf54327
Remove repetition of test cases
zanmato1984 Dec 12, 2024
53f3c33
Doc about output length
zanmato1984 Dec 12, 2024
688b4d0
Fix ci error
zanmato1984 Dec 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -771,13 +771,14 @@ if(ARROW_COMPUTE)
compute/kernels/scalar_validity.cc
compute/kernels/vector_array_sort.cc
compute/kernels/vector_cumulative_ops.cc
compute/kernels/vector_pairwise.cc
compute/kernels/vector_nested.cc
compute/kernels/vector_pairwise.cc
compute/kernels/vector_rank.cc
compute/kernels/vector_replace.cc
compute/kernels/vector_run_end_encode.cc
compute/kernels/vector_select_k.cc
compute/kernels/vector_sort.cc
compute/kernels/vector_swizzle.cc
compute/key_hash_internal.cc
compute/key_map_internal.cc
compute/light_array_internal.cc
Expand Down
33 changes: 33 additions & 0 deletions cpp/src/arrow/compute/api_vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,12 @@ static auto kPairwiseOptionsType = GetFunctionOptionsType<PairwiseOptions>(
DataMember("periods", &PairwiseOptions::periods));
static auto kListFlattenOptionsType = GetFunctionOptionsType<ListFlattenOptions>(
DataMember("recursive", &ListFlattenOptions::recursive));
static auto kInversePermutationOptionsType =
GetFunctionOptionsType<InversePermutationOptions>(
DataMember("max_index", &InversePermutationOptions::max_index),
DataMember("output_type", &InversePermutationOptions::output_type));
static auto kScatterOptionsType = GetFunctionOptionsType<ScatterOptions>(
DataMember("max_index", &ScatterOptions::max_index));
} // namespace
} // namespace internal

Expand Down Expand Up @@ -230,6 +236,17 @@ ListFlattenOptions::ListFlattenOptions(bool recursive)
: FunctionOptions(internal::kListFlattenOptionsType), recursive(recursive) {}
constexpr char ListFlattenOptions::kTypeName[];

InversePermutationOptions::InversePermutationOptions(
int64_t max_index, std::shared_ptr<DataType> output_type)
: FunctionOptions(internal::kInversePermutationOptionsType),
max_index(max_index),
output_type(std::move(output_type)) {}
constexpr char InversePermutationOptions::kTypeName[];

ScatterOptions::ScatterOptions(int64_t max_index)
: FunctionOptions(internal::kScatterOptionsType), max_index(max_index) {}
constexpr char ScatterOptions::kTypeName[];

namespace internal {
void RegisterVectorOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kFilterOptionsType));
Expand All @@ -244,6 +261,8 @@ void RegisterVectorOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kRankOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kPairwiseOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kListFlattenOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kInversePermutationOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kScatterOptionsType));
}
} // namespace internal

Expand Down Expand Up @@ -429,5 +448,19 @@ Result<Datum> CumulativeMean(const Datum& values, const CumulativeOptions& optio
return CallFunction("cumulative_mean", {Datum(values)}, &options, ctx);
}

// ----------------------------------------------------------------------
// Swizzle functions

Result<Datum> InversePermutation(const Datum& indices,
const InversePermutationOptions& options,
ExecContext* ctx) {
return CallFunction("inverse_permutation", {indices}, &options, ctx);
}

Result<Datum> Scatter(const Datum& values, const Datum& indices,
const ScatterOptions& options, ExecContext* ctx) {
return CallFunction("scatter", {values, indices}, &options, ctx);
}

} // namespace compute
} // namespace arrow
81 changes: 81 additions & 0 deletions cpp/src/arrow/compute/api_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,40 @@ class ARROW_EXPORT ListFlattenOptions : public FunctionOptions {
bool recursive = false;
};

/// \brief Options for inverse_permutation function
class ARROW_EXPORT InversePermutationOptions : public FunctionOptions {
public:
explicit InversePermutationOptions(int64_t max_index = -1,
std::shared_ptr<DataType> output_type = NULLPTR);
static constexpr char const kTypeName[] = "InversePermutationOptions";
static InversePermutationOptions Defaults() { return InversePermutationOptions(); }

/// \brief The max value in the input indices to allow. The length of the function's
/// output will be this value plus 1. If negative, this value will be set to the length
/// of the input indices minus 1 and the length of the function's output will be the
/// length of the input indices.
int64_t max_index = -1;
/// \brief The type of the output inverse permutation. If null, the output will be of
/// the same type as the input indices, otherwise must be signed integer type. An
/// invalid error will be reported if this type is not able to store the length of the
/// input indices.
std::shared_ptr<DataType> output_type = NULLPTR;
zanmato1984 marked this conversation as resolved.
Show resolved Hide resolved
};

/// \brief Options for scatter function
class ARROW_EXPORT ScatterOptions : public FunctionOptions {
public:
explicit ScatterOptions(int64_t max_index = -1);
static constexpr char const kTypeName[] = "ScatterOptions";
static ScatterOptions Defaults() { return ScatterOptions(); }

/// \brief The max value in the input indices to allow. The length of the function's
/// output will be this value plus 1. If negative, this value will be set to the length
/// of the input indices minus 1 and the length of the function's output will be the
/// length of the input indices.
int64_t max_index = -1;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it clearer to name it something like output_length? (you would have to offset its definition by 1)
@felipecrv

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, it was exactly output_length until I modified it according to @felipecrv 's suggestion. See #44394 (comment)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I personally prefer output_length as it's more straightforward and saves a little implementation code actually. On the other hand, max_index contains more "semantic" though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahah. Then perhaps we can just update the doc to say that the output length will be max_index + 1 if this option is set?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. Is updating here 53f3c33 (#44394) enough?

};

/// @}

/// \brief Filter with a boolean selection filter
Expand Down Expand Up @@ -705,5 +739,52 @@ Result<std::shared_ptr<Array>> PairwiseDiff(const Array& array,
bool check_overflow = false,
ExecContext* ctx = NULLPTR);

/// \brief Return the inverse permutation of the given indices.
///
/// For indices[i] = x, inverse_permutation[x] = i. And inverse_permutation[x] = null if x
/// does not appear in the input indices. Indices must be in the range of [0, max_index],
/// or null, which will be ignored. If multiple indices point to the same value, the last
/// one is used.
///
/// For example, with indices = [null, 0, 3, 2, 4, 1, 1], the inverse permutation is
/// [1, 6, 3, 2, 4, null, null] if max_index = 6.
///
/// \param[in] indices array-like indices
/// \param[in] options configures the max index and the output type
/// \param[in] ctx the function execution context, optional
/// \return the resulting inverse permutation
///
/// \since 19.0.0
/// \note API not yet finalized
ARROW_EXPORT
Result<Datum> InversePermutation(
const Datum& indices,
const InversePermutationOptions& options = InversePermutationOptions::Defaults(),
ExecContext* ctx = NULLPTR);

/// \brief Scatter the values into specified positions according to the indices.
///
/// For indices[i] = x, output[x] = values[i]. And output[x] = null if x does not appear
/// in the input indices. Indices must be in the range of [0, max_index], or null, in
/// which case the corresponding value will be ignored. If multiple indices point to the
/// same value, the last one is used.
///
/// For example, with values = [a, b, c, d, e, f, g] and indices = [null, 0,
/// 3, 2, 4, 1, 1], the output is
/// [b, g, d, c, e, null, null] if max_index = 6.
///
/// \param[in] values datum to scatter
/// \param[in] indices array-like indices
/// \param[in] options configures the max index of to scatter
/// \param[in] ctx the function execution context, optional
/// \return the resulting datum
///
/// \since 19.0.0
/// \note API not yet finalized
ARROW_EXPORT
Result<Datum> Scatter(const Datum& values, const Datum& indices,
const ScatterOptions& options = ScatterOptions::Defaults(),
ExecContext* ctx = NULLPTR);

} // namespace compute
} // namespace arrow
4 changes: 4 additions & 0 deletions cpp/src/arrow/compute/function_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ TEST(FunctionOptions, Equality) {
options.emplace_back(new SelectKOptions(5, {{SortKey("key", SortOrder::Ascending)}}));
options.emplace_back(new Utf8NormalizeOptions());
options.emplace_back(new Utf8NormalizeOptions(Utf8NormalizeOptions::NFD));
options.emplace_back(
new InversePermutationOptions(/*max_index=*/42, /*output_type=*/int32()));
options.emplace_back(new ScatterOptions());
options.emplace_back(new ScatterOptions(/*max_index=*/42));

for (size_t i = 0; i < options.size(); i++) {
const size_t prev_i = i == 0 ? options.size() - 1 : i - 1;
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/arrow/compute/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ add_arrow_compute_test(vector_selection_test
EXTRA_LINK_LIBS
arrow_compute_kernels_testing)

add_arrow_compute_test(vector_swizzle_test
SOURCES
vector_swizzle_test.cc
EXTRA_LINK_LIBS
arrow_compute_kernels_testing)

add_arrow_benchmark(vector_hash_benchmark PREFIX "arrow-compute")
add_arrow_benchmark(vector_sort_benchmark PREFIX "arrow-compute")
add_arrow_benchmark(vector_partition_benchmark PREFIX "arrow-compute")
Expand Down
5 changes: 3 additions & 2 deletions cpp/src/arrow/compute/kernels/codegen_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1037,8 +1037,9 @@ ArrayKernelExec GenerateFloatingPoint(detail::GetTypeId get_id) {
// Generate a kernel given a templated functor for integer types
//
// See "Numeric" above for description of the generator functor
template <template <typename...> class Generator, typename Type0, typename... Args>
ArrayKernelExec GenerateInteger(detail::GetTypeId get_id) {
template <template <typename...> class Generator, typename Type0,
typename KernelType = ArrayKernelExec, typename... Args>
KernelType GenerateInteger(detail::GetTypeId get_id) {
zanmato1984 marked this conversation as resolved.
Show resolved Hide resolved
switch (get_id.id) {
case Type::INT8:
return Generator<Type0, Int8Type, Args...>::Exec;
Expand Down
Loading
Loading