From aa26f28a64b7638f01756d78a2ea8fbddceafc65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ra=C3=BAl=20Cumplido?= Date: Mon, 18 Nov 2024 12:10:57 +0100 Subject: [PATCH] GH-44651: [Python] Allow from_buffers to work with StringView on Python (#44701) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Rationale for this change Currently `from_buffers` is not working with StringView on Python because we validate against num_buffers. This only take into account the mandatory buffers but does not take into account the variadic_spec that can be present for both string_view and binary_view ### What changes are included in this PR? Take into account whether the type contains a variadic_spec for the non-mandatory buffers and only check lower_bound number of buffers. ### Are these changes tested? Yes, I've added a couple of tests. ### Are there any user-facing changes? We are exposing a new method on the Python DataType. `has_variadic_buffers` which tells us whether the number of buffers expected is only lower-bounded by num_buffers. * GitHub Issue: #44651 Authored-by: Raúl Cumplido Signed-off-by: Raúl Cumplido --- python/pyarrow/array.pxi | 7 ++++++- python/pyarrow/includes/libarrow.pxd | 1 + python/pyarrow/tests/test_array.py | 26 ++++++++++++++++++++++++++ python/pyarrow/tests/test_types.py | 8 ++++++++ python/pyarrow/types.pxi | 16 ++++++++++++++++ 5 files changed, 57 insertions(+), 1 deletion(-) diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index eaedbf1e38580..8bddc34e1000b 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -1174,7 +1174,12 @@ cdef class Array(_PandasConvertible): "({0}) did not match the passed number " "({1}).".format(type.num_fields, len(children))) - if type.num_buffers != len(buffers): + if type.has_variadic_buffers: + if type.num_buffers > len(buffers): + raise ValueError("Type's expected number of buffers is at least " + "{0}, but the passed number is " + "{1}.".format(type.num_buffers, len(buffers))) + elif type.num_buffers != len(buffers): raise ValueError("Type's expected number of buffers " "({0}) did not match the passed number " "({1}).".format(type.num_buffers, len(buffers))) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index a70cb91873e45..8bf61b73cc211 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -158,6 +158,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: cdef cppclass CDataTypeLayout" arrow::DataTypeLayout": vector[CBufferSpec] buffers + optional[CBufferSpec] variadic_spec c_bool has_dictionary cdef cppclass CDataType" arrow::DataType": diff --git a/python/pyarrow/tests/test_array.py b/python/pyarrow/tests/test_array.py index 4160d64829483..885442b079c5b 100644 --- a/python/pyarrow/tests/test_array.py +++ b/python/pyarrow/tests/test_array.py @@ -651,6 +651,32 @@ def test_string_binary_from_buffers(): assert copied.null_count == 0 +def test_string_view_from_buffers(): + array = pa.array( + [ + "String longer than 12 characters", + None, + "short", + "Length is 12" + ], type=pa.string_view()) + + buffers = array.buffers() + copied = pa.StringViewArray.from_buffers( + pa.string_view(), len(array), buffers) + copied.validate(full=True) + assert copied.to_pylist() == [ + "String longer than 12 characters", + None, + "short", + "Length is 12" + ] + + match = r"number of buffers is at least 2" + with pytest.raises(ValueError, match=match): + pa.StringViewArray.from_buffers( + pa.string_view(), len(array), buffers[0:1]) + + @pytest.mark.parametrize('list_type_factory', [ pa.list_, pa.large_list, pa.list_view, pa.large_list_view]) def test_list_from_buffers(list_type_factory): diff --git a/python/pyarrow/tests/test_types.py b/python/pyarrow/tests/test_types.py index fef350d5de958..de439b6bb8cd7 100644 --- a/python/pyarrow/tests/test_types.py +++ b/python/pyarrow/tests/test_types.py @@ -887,6 +887,14 @@ def test_types_weakref(): assert wr() is None # not a singleton +def test_types_has_variadic_buffers(): + for ty in get_many_types(): + if ty in (pa.string_view(), pa.binary_view()): + assert ty.has_variadic_buffers + else: + assert not ty.has_variadic_buffers + + def test_fields_hashable(): in_dict = {} fields = [pa.field('a', pa.int32()), diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index 4aa8238556a9c..0d6787cf2a049 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -326,6 +326,22 @@ cdef class DataType(_Weakrefable): """ return self.type.layout().buffers.size() + @property + def has_variadic_buffers(self): + """ + If True, the number of expected buffers is only + lower-bounded by num_buffers. + + Examples + -------- + >>> import pyarrow as pa + >>> pa.int64().has_variadic_buffers + False + >>> pa.string_view().has_variadic_buffers + True + """ + return self.type.layout().variadic_spec.has_value() + def __str__(self): return frombytes(self.type.ToString(), safe=True)