From ccf36967fd5b16c93fd32c86892b4f60835fe2da Mon Sep 17 00:00:00 2001 From: Rok Mihevc Date: Mon, 4 Sep 2023 06:05:11 +0200 Subject: [PATCH] Add FixedShapeTensorScalarType --- python/pyarrow/array.pxi | 20 ++++++++++++++++++++ python/pyarrow/tests/test_extension_type.py | 3 +++ 2 files changed, 23 insertions(+) diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index e26b1ad3291b5..530f578ef5993 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -3325,6 +3325,23 @@ cdef class ExtensionArray(Array): return result +class FixedShapeTensorScalarType(ExtensionScalar): + """ + Concrete class for fixed shape tensor extension scalar type. + """ + def to_numpy_ndarray(self): + """ + Convert fixed shape tensor extension scalar to a numpy array (with dim+1). + """ + self.to_numpy_ndarray() + + def as_py(self): + """ + Convert fixed shape tensor extension scalar to a python list. + """ + return self.to_numpy_ndarray().tolist() + + class FixedShapeTensorArray(ExtensionArray): """ Concrete class for fixed shape tensor extension arrays. @@ -3433,6 +3450,9 @@ class FixedShapeTensorArray(ExtensionArray): FixedSizeListArray.from_arrays(np.ravel(obj, order='C'), size) ) + def __arrow_ext_scalar_class__(self): + return FixedShapeTensorScalarType + cdef dict _array_classes = { _Type_NA: NullArray, diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index 1eb7d5fa76188..703a638078de3 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -1247,6 +1247,9 @@ def test_tensor_class_methods(): with pytest.raises(ValueError, match="non-permuted tensors"): arr.to_numpy_ndarray() + for i in range(expected.shape[0]): + assert arr[i].to_numpy_ndarray() == expected[i] + @pytest.mark.parametrize("tensor_type", ( pa.fixed_shape_tensor(pa.int8(), [2, 2, 3]),