Skip to content

Commit

Permalink
Use GetTensor to get numpy values.
Browse files Browse the repository at this point in the history
  • Loading branch information
rok committed Sep 25, 2023
1 parent 18a60d0 commit 151da9d
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 36 deletions.
19 changes: 19 additions & 0 deletions cpp/src/arrow/extension/fixed_shape_tensor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,7 @@ TEST_F(TestExtensionType, GetScalar) {
}

TEST_F(TestExtensionType, GetTensor) {
// Get tensor from extension array
auto ext_type = fixed_shape_tensor(value_type_, cell_shape_, {}, dim_names_);
auto arr = ArrayFromJSON(cell_type_,
"[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],"
Expand All @@ -478,8 +479,26 @@ TEST_F(TestExtensionType, GetTensor) {
ASSERT_EQ(expected_tensor->shape(), actual_tensor->shape());
ASSERT_EQ(expected_tensor->dim_names(), actual_tensor->dim_names());
ASSERT_EQ(expected_tensor->strides(), actual_tensor->strides());
ASSERT_EQ(actual_tensor->strides(), std::vector<int64_t>({32, 8}));
ASSERT_EQ(expected_tensor->type(), actual_tensor->type());
ASSERT_TRUE(expected_tensor->Equals(*actual_tensor));

// Get tensor from extension array with non-trivial permutation
auto permuted_ext_type = fixed_shape_tensor(value_type_, {3, 4}, {1, 0}, {"x", "y"});
auto permuted_array = std::static_pointer_cast<FixedShapeTensorArray>(
ExtensionType::WrapArray(permuted_ext_type, arr));

std::vector<int64_t> values_second_cell = {12, 13, 14, 15, 16, 17, 18, 19, 20, 21};
ASSERT_OK_AND_ASSIGN(auto expected_permuted_tensor,
Tensor::Make(value_type_, Buffer::Wrap(values_second_cell), {4, 3},
{8, 24}, {"y", "x"}));

ASSERT_OK_AND_ASSIGN(auto actual_permuted_tensor, permuted_array->GetTensor(1));
ASSERT_EQ(expected_permuted_tensor->shape(), actual_permuted_tensor->shape());
ASSERT_EQ(expected_permuted_tensor->dim_names(), actual_permuted_tensor->dim_names());
ASSERT_EQ(expected_permuted_tensor->strides(), actual_permuted_tensor->strides());
ASSERT_EQ(expected_permuted_tensor->type(), actual_permuted_tensor->type());
ASSERT_TRUE(expected_permuted_tensor->Equals(*actual_permuted_tensor));
}

} // namespace arrow
46 changes: 17 additions & 29 deletions python/pyarrow/array.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -3325,27 +3325,7 @@ cdef class ExtensionArray(Array):
return result


class FixedShapeTensorScalar(ExtensionScalar):
"""
Concrete class for fixed shape tensor extension scalar.
"""

def to_numpy_ndarray(self):
"""
Convert fixed shape tensor extension scalar to a numpy array (with dim).
Note: ``permutation`` should be trivial (``None`` or ``[0, 1, ..., len(shape)-1]``).
"""
if self.type.permutation is None or self.type.permutation == list(range(len(self.type.shape))):
np_flat = self.value.values.to_numpy()
numpy_tensor = np_flat.reshape(tuple(self.type.shape))
return numpy_tensor
else:
raise ValueError(
'Only non-permuted tensors can be converted to numpy tensors.')


class FixedShapeTensorArray(ExtensionArray):
cdef class FixedShapeTensorArray(ExtensionArray):
"""
Concrete class for fixed shape tensor extension arrays.
Expand Down Expand Up @@ -3387,16 +3367,24 @@ class FixedShapeTensorArray(ExtensionArray):
def to_numpy_ndarray(self):
"""
Convert fixed shape tensor extension array to a numpy array (with dim+1).
"""
cdef:
CFixedShapeTensorArray* ext_array = <CFixedShapeTensorArray*>(self.ap)
CResult[shared_ptr[CTensor]] ctensor
with nogil:
ctensor = ext_array.ToTensor()
return pyarrow_wrap_tensor(GetResultValue(ctensor)).to_numpy()

Note: ``permutation`` should be trivial (``None`` or ``[0, 1, ..., len(shape)-1]``).
def get_tensor(self, int64_t i):
"""
if self.type.permutation is None or self.type.permutation == list(range(len(self.type.shape))):
np_flat = np.asarray(self.storage.flatten())
numpy_tensor = np_flat.reshape((len(self),) + tuple(self.type.shape))
return numpy_tensor
else:
raise ValueError(
'Only non-permuted tensors can be converted to numpy tensors.')
Convert variable shape tensor extension array to list of numpy arrays.
"""
cdef:
CFixedShapeTensorArray* ext_array = <CFixedShapeTensorArray*>(self.ap)
CResult[shared_ptr[CTensor]] ctensor
with nogil:
ctensor = ext_array.GetTensor(i)
return pyarrow_wrap_tensor(GetResultValue(ctensor))

@staticmethod
def from_numpy_ndarray(obj):
Expand Down
4 changes: 4 additions & 0 deletions python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -2654,6 +2654,10 @@ cdef extern from "arrow/extension/fixed_shape_tensor.h" namespace "arrow::extens
const vector[int64_t] permutation()
const vector[c_string] dim_names()

cdef cppclass CFixedShapeTensorArray \
" arrow::extension::FixedShapeTensorArray"(CExtensionArray) nogil:
const CResult[shared_ptr[CTensor]] GetTensor(const int64_t i) const
const CResult[shared_ptr[CTensor]] ToTensor() const

cdef extern from "arrow/util/compression.h" namespace "arrow" nogil:
cdef enum CCompressionType" arrow::Compression::type":
Expand Down
22 changes: 22 additions & 0 deletions python/pyarrow/scalar.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,28 @@ cdef class ExtensionScalar(Scalar):
return pyarrow_wrap_scalar(<shared_ptr[CScalar]> sp_scalar)


class FixedShapeTensorScalar(ExtensionScalar):
"""
Concrete class for fixed shape tensor extension scalar.
"""

def to_numpy_ndarray(self):
# TODO: allow any permutation
"""
Convert fixed shape tensor extension scalar to a numpy array (with dim).
Note: ``permutation`` should be trivial (``None`` or ``[0, 1, ..., len(shape)-1]``).
"""

if self.type.permutation is None or self.type.permutation == list(range(len(self.type.shape))):
np_flat = np.asarray(self.storage.flatten())
numpy_tensor = np_flat.reshape(tuple(self.type.shape))
return numpy_tensor
else:
raise ValueError(
'Only non-permuted tensors can be converted to numpy tensors.')


cdef dict _scalar_classes = {
_Type_BOOL: BooleanScalar,
_Type_UINT8: UInt8Scalar,
Expand Down
21 changes: 14 additions & 7 deletions python/pyarrow/tests/test_extension_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,37 +1215,44 @@ def test_tensor_type():

def test_tensor_class_methods():
tensor_type = pa.fixed_shape_tensor(pa.float32(), [2, 3])
storage = pa.array([[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6]],
storage = pa.array([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]],
pa.list_(pa.float32(), 6))
arr = pa.ExtensionArray.from_storage(tensor_type, storage)

# TODO: add more get_tensor tests
assert arr.get_tensor(0) == pa.Tensor.from_numpy(
np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32))

expected = np.array(
[[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]], dtype=np.float32)
[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], dtype=np.float32)
result = arr.to_numpy_ndarray()
np.testing.assert_array_equal(result, expected)

expected = np.array([[[1, 2, 3], [4, 5, 6]]], dtype=np.float32)
expected = np.array([[[7, 8, 9], [10, 11, 12]]], dtype=np.float32)
result = arr[:1].to_numpy_ndarray()
np.testing.assert_array_equal(result, expected)

arr = np.array(
[[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]],
[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
dtype=np.float32, order="C")
tensor_array_from_numpy = pa.FixedShapeTensorArray.from_numpy_ndarray(arr)
assert isinstance(tensor_array_from_numpy.type, pa.FixedShapeTensorType)
assert tensor_array_from_numpy.type.value_type == pa.float32()
assert tensor_array_from_numpy.type.shape == [2, 3]

arr = np.array(
[[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]],
[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
dtype=np.float32, order="F")
with pytest.raises(ValueError, match="C-style contiguous segment"):
pa.FixedShapeTensorArray.from_numpy_ndarray(arr)

storage = pa.array([[1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6]], pa.list_(pa.int8(), 12))
storage = pa.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]],
pa.list_(pa.int8(), 12))
expected = np.array(
[[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]], dtype=np.int8)
[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], dtype=np.int8)

tensor_type = pa.fixed_shape_tensor(pa.int8(), [2, 2, 3], permutation=[0, 2, 1])
storage = pa.array([[1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6]], pa.list_(pa.int8(), 12))
arr = pa.ExtensionArray.from_storage(tensor_type, storage)
with pytest.raises(ValueError, match="non-permuted tensors"):
arr.to_numpy_ndarray()
Expand Down

0 comments on commit 151da9d

Please sign in to comment.