
joining tables with ExtensionArrays #44473

Open
NellyWhads opened this issue Oct 18, 2024 · 0 comments

NellyWhads commented Oct 18, 2024

Describe the enhancement requested

I'm looking for documentation on how to implement an ExtensionArray which supports join functionality.

In particular, I'd like to join a table that includes a FixedShapeTensorArray column with another table.

Here's a minimal example that fails:

```python
import numpy as np
import pyarrow as pa

# First dim is the batch dim
tensors = np.arange(3 * 10 * 10).reshape((3, 10, 10)).astype(np.uint8)
tensor_array = pa.FixedShapeTensorArray.from_numpy_ndarray(tensors)
ids = pa.array([1, 2, 3], type=pa.uint8())
table = pa.Table.from_arrays([ids, tensor_array], names=["id", "tensor"])
print(table.schema)

classes = pa.array(["one", "two", "three"], type=pa.string())
table_2 = pa.Table.from_arrays([ids, classes], names=["id", "name"])
print(table_2.schema)

# Raises ArrowInvalid: the extension-typed "tensor" column is rejected by the join
table.join(table_2, keys=["id"], join_type="full outer")
```

This raises the following error:

```
---------------------------------------------------------------------------
ArrowInvalid                              Traceback (most recent call last)
Cell In[42], line 1
----> 1 table.join(table_2, keys=["id"], join_type="full outer")

File ~/.pyenv/versions/next_gen_data_38/lib/python3.8/site-packages/pyarrow/table.pxi:5570, in pyarrow.lib.Table.join()

File ~/.pyenv/versions/next_gen_data_38/lib/python3.8/site-packages/pyarrow/acero.py:247, in _perform_join(join_type, left_operand, left_keys, right_operand, right_keys, left_suffix, right_suffix, use_threads, coalesce_keys, output_type)
    242     projection = Declaration(
    243         "project", ProjectNodeOptions(projections, projected_col_names)
    244     )
    245     decl = Declaration.from_sequence([decl, projection])
--> 247 result_table = decl.to_table(use_threads=use_threads)
    249 if output_type == Table:
    250     return result_table

File ~/.pyenv/versions/next_gen_data_38/lib/python3.8/site-packages/pyarrow/_acero.pyx:590, in pyarrow._acero.Declaration.to_table()

File ~/.pyenv/versions/next_gen_data_38/lib/python3.8/site-packages/pyarrow/error.pxi:155, in pyarrow.lib.pyarrow_internal_check_status()

File ~/.pyenv/versions/next_gen_data_38/lib/python3.8/site-packages/pyarrow/error.pxi:92, in pyarrow.lib.check_status()

ArrowInvalid: Data type extension<arrow.fixed_shape_tensor[value_type=uint8, shape=[10,10], permutation=[0,1]]> is not supported in join non-key field tensor
```

How can I make this work? The individual tensors I want to store are small (each dimension is in the single digits), but the join may be followed by a list aggregation over a few hundred rows.

I've tagged this as a Python question because I don't know which level of the API needs to be adjusted to add this functionality.

Component(s)

Python
