Skip to content

Commit

Permalink
Add FixedShapeTensorScalarType
Browse files Browse the repository at this point in the history
  • Loading branch information
rok committed Sep 4, 2023
1 parent 3b14c74 commit ccf3696
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
20 changes: 20 additions & 0 deletions python/pyarrow/array.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -3325,6 +3325,23 @@ cdef class ExtensionArray(Array):
return result


class FixedShapeTensorScalarType(ExtensionScalar):
"""
Concrete class for fixed shape tensor extension scalar type.
"""
def to_numpy_ndarray(self):
"""
Convert fixed shape tensor extension scalar to a numpy array (with dim+1).
"""
self.to_numpy_ndarray()

def as_py(self):
"""
Convert fixed shape tensor extension scalar to a python list.
"""
return self.to_numpy_ndarray().tolist()


class FixedShapeTensorArray(ExtensionArray):
"""
Concrete class for fixed shape tensor extension arrays.
Expand Down Expand Up @@ -3433,6 +3450,9 @@ class FixedShapeTensorArray(ExtensionArray):
FixedSizeListArray.from_arrays(np.ravel(obj, order='C'), size)
)

def __arrow_ext_scalar_class__(self):
return FixedShapeTensorScalarType


cdef dict _array_classes = {
_Type_NA: NullArray,
Expand Down
3 changes: 3 additions & 0 deletions python/pyarrow/tests/test_extension_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -1247,6 +1247,9 @@ def test_tensor_class_methods():
with pytest.raises(ValueError, match="non-permuted tensors"):
arr.to_numpy_ndarray()

for i in range(expected.shape[0]):
assert arr[i].to_numpy_ndarray() == expected[i]


@pytest.mark.parametrize("tensor_type", (
pa.fixed_shape_tensor(pa.int8(), [2, 2, 3]),
Expand Down

0 comments on commit ccf3696

Please sign in to comment.