From 9fba756366867ac8ff29ba11ce8f03c920360a36 Mon Sep 17 00:00:00 2001
From: Alexey Kudinkin
Date: Wed, 13 Nov 2024 11:42:22 -0800
Subject: [PATCH] [Arrow] Enabling V2 Arrow Tensor extension type by default
 (allowing tensors > 2GB) (#48629)

Enabling V2 Arrow Tensor extension type by default (allowing tensors > 2GB)

---------

Signed-off-by: Alexey Kudinkin
---
 python/ray/air/util/data_batch_conversion.py |  8 +++-
 .../ray/air/util/tensor_extensions/arrow.py  | 29 +++++++++++---
 .../_internal/arrow_ops/transform_pyarrow.py | 17 ++++----
 python/ray/data/context.py                   |  2 +-
 python/ray/data/dataset.py                   | 13 +++++--
 python/ray/data/tests/test_ecosystem.py      |  6 ++-
 python/ray/data/tests/test_image.py          |  8 ++--
 python/ray/data/tests/test_numpy.py          | 39 ++++++++++++-------
 python/ray/data/tests/test_pandas.py         |  7 +++-
 python/ray/data/tests/test_parquet.py        |  8 +++-
 python/ray/data/tests/test_strict_mode.py    | 13 ++++++-
 .../ray/data/tests/test_transform_pyarrow.py | 14 ++++++-
 12 files changed, 116 insertions(+), 48 deletions(-)

diff --git a/python/ray/air/util/data_batch_conversion.py b/python/ray/air/util/data_batch_conversion.py
index e134b5b1d31f..4fe7a8ab2ea9 100644
--- a/python/ray/air/util/data_batch_conversion.py
+++ b/python/ray/air/util/data_batch_conversion.py
@@ -6,6 +6,9 @@
 
 from ray.air.constants import TENSOR_COLUMN_NAME
 from ray.air.data_batch_type import DataBatchType
+from ray.air.util.tensor_extensions.arrow import (
+    get_arrow_extension_fixed_shape_tensor_types,
+)
 from ray.util.annotations import Deprecated, DeveloperAPI
 
 if TYPE_CHECKING:
@@ -217,14 +220,15 @@ def _convert_batch_type_to_numpy(
         )
         return data
     elif pyarrow is not None and isinstance(data, pyarrow.Table):
-        from ray.air.util.tensor_extensions.arrow import ArrowTensorType
         from ray.air.util.transform_pyarrow import (
             _concatenate_extension_column,
             _is_column_extension_type,
         )
 
         if data.column_names == [TENSOR_COLUMN_NAME] and (
-            isinstance(data.schema.types[0], ArrowTensorType)
+            isinstance(
+                data.schema.types[0], get_arrow_extension_fixed_shape_tensor_types()
+            )
         ):
             # If representing a tensor dataset, return as a single numpy array.
             # Example: ray.data.from_numpy(np.arange(12).reshape((3, 2, 2)))
diff --git a/python/ray/air/util/tensor_extensions/arrow.py b/python/ray/air/util/tensor_extensions/arrow.py
index 4b6940748a9c..bef940c136f9 100644
--- a/python/ray/air/util/tensor_extensions/arrow.py
+++ b/python/ray/air/util/tensor_extensions/arrow.py
@@ -100,7 +100,26 @@ def get_arrow_extension_tensor_types():
     """Returns list of extension types of Arrow Array holding multidimensional
     tensors
     """
-    return ArrowTensorType, ArrowTensorTypeV2, ArrowVariableShapedTensorType
+    return (
+        *get_arrow_extension_fixed_shape_tensor_types(),
+        *get_arrow_extension_variable_shape_tensor_types(),
+    )
+
+
+@DeveloperAPI
+def get_arrow_extension_fixed_shape_tensor_types():
+    """Returns a tuple of Arrow extension types holding multidimensional
+    tensors of *fixed* shape
+    """
+    return ArrowTensorType, ArrowTensorTypeV2
+
+
+@DeveloperAPI
+def get_arrow_extension_variable_shape_tensor_types():
+    """Returns a tuple of Arrow extension types holding multidimensional
+    tensors of *variable* shape
+    """
+    return (ArrowVariableShapedTensorType,)
 
 
 class _BaseFixedShapeArrowTensorType(pa.ExtensionType, abc.ABC):
@@ -225,7 +244,7 @@ def _need_variable_shaped_tensor_array(
             # short-circuit since we require a variable-shaped representation.
             if isinstance(arr_type, ArrowVariableShapedTensorType):
                 return True
-            if not isinstance(arr_type, (ArrowTensorType, ArrowTensorTypeV2)):
+            if not isinstance(arr_type, get_arrow_extension_fixed_shape_tensor_types()):
                 raise ValueError(
                     "All provided array types must be an instance of either "
                     "ArrowTensorType or ArrowVariableShapedTensorType, but "
@@ -469,9 +488,7 @@ def _from_numpy(
 
         from ray.data import DataContext
 
-        should_use_tensor_v2 = DataContext.get_current().use_arrow_tensor_v2
-
-        if should_use_tensor_v2:
+        if DataContext.get_current().use_arrow_tensor_v2:
             pa_type_ = ArrowTensorTypeV2(element_shape, scalar_dtype)
         else:
             pa_type_ = ArrowTensorType(element_shape, scalar_dtype)
@@ -635,7 +652,7 @@ def _chunk_tensor_arrays(
     if ArrowTensorType._need_variable_shaped_tensor_array(arrs_types):
         new_arrs = []
         for a in arrs:
-            if isinstance(a.type, (ArrowTensorType, ArrowTensorTypeV2)):
+            if isinstance(a.type, get_arrow_extension_fixed_shape_tensor_types()):
                 a = a.to_variable_shaped_tensor_array()
             assert isinstance(a.type, ArrowVariableShapedTensorType)
             new_arrs.append(a)
diff --git a/python/ray/data/_internal/arrow_ops/transform_pyarrow.py b/python/ray/data/_internal/arrow_ops/transform_pyarrow.py
index 749f03ed440d..093588ca8f34 100644
--- a/python/ray/data/_internal/arrow_ops/transform_pyarrow.py
+++ b/python/ray/data/_internal/arrow_ops/transform_pyarrow.py
@@ -3,7 +3,6 @@
 from packaging.version import parse as parse_version
 
 from ray._private.utils import _get_pyarrow_version
-from ray.air.util.tensor_extensions.arrow import ArrowTensorTypeV2
 
 try:
     import pyarrow
@@ -90,11 +89,14 @@ def unify_schemas(
                 cols_with_null_list.add(col_name)
         all_columns.add(col_name)
 
-    arrow_tensor_types = (
-        ArrowVariableShapedTensorType,
-        ArrowTensorType,
-        ArrowTensorTypeV2,
+    from ray.air.util.tensor_extensions.arrow import (
+        get_arrow_extension_fixed_shape_tensor_types,
+        get_arrow_extension_tensor_types,
     )
+
+    arrow_tensor_types = get_arrow_extension_tensor_types()
+    arrow_fixed_shape_tensor_types = get_arrow_extension_fixed_shape_tensor_types()
+
     columns_with_objects = set()
     columns_with_tensor_array = set()
     for col_name in all_columns:
@@ -124,12 +126,11 @@
                 for s in schemas
                 if isinstance(s.field(col_name).type, arrow_tensor_types)
             ]
+
            if ArrowTensorType._need_variable_shaped_tensor_array(tensor_array_types):
                 if isinstance(tensor_array_types[0], ArrowVariableShapedTensorType):
                     new_type = tensor_array_types[0]
-                elif isinstance(
-                    tensor_array_types[0], (ArrowTensorType, ArrowTensorTypeV2)
-                ):
+                elif isinstance(tensor_array_types[0], arrow_fixed_shape_tensor_types):
                     new_type = ArrowVariableShapedTensorType(
                         dtype=tensor_array_types[0].scalar_type,
                         ndim=len(tensor_array_types[0].shape),
diff --git a/python/ray/data/context.py b/python/ray/data/context.py
index f43fc2a50246..5ed9b4fe68ef 100644
--- a/python/ray/data/context.py
+++ b/python/ray/data/context.py
@@ -78,7 +78,7 @@
 # total cumulative size (due to it internally utilizing int32 offsets)
 #
 # V2 in turn relies on int64 offsets, therefore having a limit of ~9Eb (exabytes)
-DEFAULT_USE_ARROW_TENSOR_V2 = env_bool("RAY_DATA_USE_ARROW_TENSOR_V2", False)
+DEFAULT_USE_ARROW_TENSOR_V2 = env_bool("RAY_DATA_USE_ARROW_TENSOR_V2", True)
 
 DEFAULT_AUTO_LOG_STATS = False
 
diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index c3fed90e4afc..d576b8eb2ea7 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -28,7 +28,10 @@
 import ray.cloudpickle as pickle
 from ray._private.thirdparty.tabulate.tabulate import tabulate
 from ray._private.usage import usage_lib
-from ray.air.util.tensor_extensions.arrow import ArrowTensorTypeV2
+from ray.air.util.tensor_extensions.arrow import (
+    ArrowTensorTypeV2,
+    get_arrow_extension_fixed_shape_tensor_types,
+)
 from ray.air.util.tensor_extensions.utils import _create_possibly_ragged_ndarray
 from ray.data._internal.aggregate import Max, Mean, Min, Std, Sum
 from ray.data._internal.compute import ComputeStrategy
@@ -4478,15 +4481,17 @@ def block_to_df(block_ref: ObjectRef[Block]) -> pd.DataFrame:
                 }
             )
         elif pa is not None and isinstance(schema, pa.Schema):
-            from ray.data.extensions import ArrowTensorType
+            arrow_tensor_ext_types = get_arrow_extension_fixed_shape_tensor_types()
 
-            if any(isinstance(type_, ArrowTensorType) for type_ in schema.types):
+            if any(
+                isinstance(type_, arrow_tensor_ext_types) for type_ in schema.types
+            ):
                 meta = pd.DataFrame(
                     {
                         col: pd.Series(
                             dtype=(
                                 dtype.to_pandas_dtype()
-                                if not isinstance(dtype, ArrowTensorType)
+                                if not isinstance(dtype, arrow_tensor_ext_types)
                                 else np.object_
                             )
                         )
diff --git a/python/ray/data/tests/test_ecosystem.py b/python/ray/data/tests/test_ecosystem.py
index 4d124907ff44..32c17ada95f4 100644
--- a/python/ray/data/tests/test_ecosystem.py
+++ b/python/ray/data/tests/test_ecosystem.py
@@ -6,9 +6,11 @@
 import pytest
 
 import ray
+from ray.air.util.tensor_extensions.arrow import (
+    get_arrow_extension_fixed_shape_tensor_types,
+)
 from ray.data.extensions.tensor_extension import (
     ArrowTensorArray,
-    ArrowTensorType,
     TensorArray,
     TensorDtype,
 )
@@ -119,7 +121,7 @@ def test_to_dask_tensor_column_cast_arrow(ray_start_regular_shared):
     in_table = pa.table({"a": ArrowTensorArray.from_numpy(data)})
     ds = ray.data.from_arrow(in_table)
     dtype = ds.schema().base_schema.field(0).type
-    assert isinstance(dtype, ArrowTensorType)
+    assert isinstance(dtype, get_arrow_extension_fixed_shape_tensor_types())
     out_df = ds.to_dask().compute()
     assert out_df["a"].dtype.type is np.object_
     expected_df = pd.DataFrame({"a": list(data)})
diff --git a/python/ray/data/tests/test_image.py b/python/ray/data/tests/test_image.py
index 67612a5a6cbf..60fa6673d0a1 100644
--- a/python/ray/data/tests/test_image.py
+++ b/python/ray/data/tests/test_image.py
@@ -9,13 +9,15 @@
 from PIL import Image
 
 import ray
+from ray.air.util.tensor_extensions.arrow import (
+    get_arrow_extension_fixed_shape_tensor_types,
+)
 from ray.data._internal.datasource.image_datasource import (
     ImageDatasource,
     ImageFileMetadataProvider,
 )
 from ray.data.datasource import Partitioning
 from ray.data.datasource.file_meta_provider import FastFileMetadataProvider
-from ray.data.extensions import ArrowTensorType
 from ray.data.tests.conftest import *  # noqa
 from ray.data.tests.mock_http_server import *  # noqa
 from ray.tests.conftest import *  # noqa
@@ -27,7 +29,7 @@ def test_basic(self, ray_start_regular_shared):
         ds = ray.data.read_images("example://image-datasets/simple")
         assert ds.schema().names == ["image"]
         column_type = ds.schema().types[0]
-        assert isinstance(column_type, ArrowTensorType)
+        assert isinstance(column_type, get_arrow_extension_fixed_shape_tensor_types())
         assert all(record["image"].shape == (32, 32, 3) for record in ds.take())
 
     @pytest.mark.parametrize("num_threads", [-1, 0, 1, 2, 4])
@@ -139,7 +141,7 @@ def test_partitioning(
         assert ds.schema().names == ["image", "label"]
 
         image_type, label_type = ds.schema().types
-        assert isinstance(image_type, ArrowTensorType)
+        assert isinstance(image_type, get_arrow_extension_fixed_shape_tensor_types())
         assert pa.types.is_string(label_type)
 
         df = ds.to_pandas()
diff --git a/python/ray/data/tests/test_numpy.py b/python/ray/data/tests/test_numpy.py
index 2e6142b393d1..7b1fd5c1d3dc 100644
--- a/python/ray/data/tests/test_numpy.py
+++ b/python/ray/data/tests/test_numpy.py
@@ -7,7 +7,8 @@
 from pytest_lazyfixture import lazy_fixture
 
 import ray
-from ray.data import Schema
+from ray.air.util.tensor_extensions.arrow import ArrowTensorTypeV2
+from ray.data import DataContext, Schema
 from ray.data.datasource import (
     BaseFileMetadataProvider,
     FastFileMetadataProvider,
@@ -23,6 +24,14 @@
 from ray.tests.conftest import *  # noqa
 
 
+def _get_tensor_type():
+    return (
+        ArrowTensorTypeV2
+        if DataContext.get_current().use_arrow_tensor_v2
+        else ArrowTensorType
+    )
+
+
 def test_numpy_read_partitioning(ray_start_regular_shared, tmp_path):
     path = os.path.join(tmp_path, "country=us", "data.npy")
     os.mkdir(os.path.dirname(path))
@@ -113,27 +122,27 @@ def test_to_numpy_refs(ray_start_regular_shared):
     ],
 )
 def test_numpy_roundtrip(ray_start_regular_shared, fs, data_path):
+    tensor_type = _get_tensor_type()
+
     ds = ray.data.range_tensor(10, override_num_blocks=2)
     ds.write_numpy(data_path, filesystem=fs, column="data")
     ds = ray.data.read_numpy(data_path, filesystem=fs)
     assert ds.count() == 10
-    assert ds.schema() == Schema(
-        pa.schema([("data", ArrowTensorType((1,), pa.int64()))])
-    )
+    assert ds.schema() == Schema(pa.schema([("data", tensor_type((1,), pa.int64()))]))
     assert sorted(ds.take_all(), key=lambda row: row["data"]) == [
         {"data": np.array([i])} for i in range(10)
     ]
 
 
-def test_numpy_read(ray_start_regular_shared, tmp_path):
+def test_numpy_read_x(ray_start_regular_shared, tmp_path):
+    tensor_type = _get_tensor_type()
+
     path = os.path.join(tmp_path, "test_np_dir")
     os.mkdir(path)
     np.save(os.path.join(path, "test.npy"), np.expand_dims(np.arange(0, 10), 1))
     ds = ray.data.read_numpy(path, override_num_blocks=1)
     assert ds.count() == 10
-    assert ds.schema() == Schema(
-        pa.schema([("data", ArrowTensorType((1,), pa.int64()))])
-    )
+    assert ds.schema() == Schema(pa.schema([("data", tensor_type((1,), pa.int64()))]))
     np.testing.assert_equal(
         extract_values("data", ds.take(2)), [np.array([0]), np.array([1])]
     )
@@ -145,9 +154,7 @@ def test_numpy_read(ray_start_regular_shared, tmp_path):
     ds = ray.data.read_numpy(path, override_num_blocks=1)
     assert ds._plan.initial_num_blocks() == 1
     assert ds.count() == 10
-    assert ds.schema() == Schema(
-        pa.schema([("data", ArrowTensorType((1,), pa.int64()))])
-    )
+    assert ds.schema() == Schema(pa.schema([("data", tensor_type((1,), pa.int64()))]))
     assert [v["data"].item() for v in ds.take(2)] == [0, 1]
 
 
@@ -174,6 +181,8 @@ def test_numpy_read_ignore_missing_paths(
 
 
 def test_numpy_read_meta_provider(ray_start_regular_shared, tmp_path):
+    tensor_type = _get_tensor_type()
+
     path = os.path.join(tmp_path, "test_np_dir")
     os.mkdir(path)
     path = os.path.join(path, "test.npy")
     np.save(path, np.expand_dims(np.arange(0, 10), 1))
     ds = ray.data.read_numpy(
         path, meta_provider=FastFileMetadataProvider(), override_num_blocks=1
     )
     assert ds.count() == 10
-    assert ds.schema() == Schema(
-        pa.schema([("data", ArrowTensorType((1,), pa.int64()))])
-    )
+    assert ds.schema() == Schema(pa.schema([("data", tensor_type((1,), pa.int64()))]))
     np.testing.assert_equal(
         extract_values("data", ds.take(2)), [np.array([0]), np.array([1])]
     )
@@ -204,6 +211,8 @@ def test_numpy_read_partitioned_with_filter(
     write_partitioned_df,
     assert_base_partitioned_ds,
 ):
+    tensor_type = _get_tensor_type()
+
     def df_to_np(dataframe, path, **kwargs):
         np.save(path, dataframe.to_numpy(dtype=np.dtype(np.int8)), **kwargs)
 
@@ -245,7 +254,7 @@ def sorted_values_transform_fn(sorted_values):
     val_str = "".join(f"array({v}, dtype=int8), " for v in vals)[:-2]
     assert_base_partitioned_ds(
         ds,
-        schema=Schema(pa.schema([("data", ArrowTensorType((2,), pa.int8()))])),
+        schema=Schema(pa.schema([("data", tensor_type((2,), pa.int8()))])),
         sorted_values=f"[[{val_str}]]",
         ds_take_transform_fn=lambda taken: [extract_values("data", taken)],
         sorted_values_transform_fn=sorted_values_transform_fn,
diff --git a/python/ray/data/tests/test_pandas.py b/python/ray/data/tests/test_pandas.py
index 383d20f55851..81a6c6eea8d5 100644
--- a/python/ray/data/tests/test_pandas.py
+++ b/python/ray/data/tests/test_pandas.py
@@ -6,9 +6,12 @@
 import pytest
 
 import ray
+from ray.air.util.tensor_extensions.arrow import (
+    get_arrow_extension_fixed_shape_tensor_types,
+)
 from ray.data._internal.execution.interfaces.ref_bundle import RefBundle
 from ray.data.block import Block
-from ray.data.extensions import ArrowTensorArray, ArrowTensorType, TensorDtype
+from ray.data.extensions import ArrowTensorArray, TensorDtype
 from ray.data.tests.conftest import *  # noqa
 from ray.data.tests.mock_http_server import *  # noqa
 from ray.tests.conftest import *  # noqa
@@ -186,7 +189,7 @@ def test_to_pandas_tensor_column_cast_arrow(ray_start_regular_shared):
     in_table = pa.table({"a": ArrowTensorArray.from_numpy(data)})
     ds = ray.data.from_arrow(in_table)
     dtype = ds.schema().base_schema.field(0).type
-    assert isinstance(dtype, ArrowTensorType)
+    assert isinstance(dtype, get_arrow_extension_fixed_shape_tensor_types())
     out_df = ds.to_pandas()
     assert out_df["a"].dtype.type is np.object_
     expected_df = pd.DataFrame({"a": list(data)})
diff --git a/python/ray/data/tests/test_parquet.py b/python/ray/data/tests/test_parquet.py
index 6ff18ed45d7d..739edb1ddd0b 100644
--- a/python/ray/data/tests/test_parquet.py
+++ b/python/ray/data/tests/test_parquet.py
@@ -12,7 +12,10 @@
 from pytest_lazyfixture import lazy_fixture
 
 import ray
-from ray.air.util.tensor_extensions.arrow import ArrowTensorType, ArrowTensorTypeV2
+from ray.air.util.tensor_extensions.arrow import (
+    ArrowTensorTypeV2,
+    get_arrow_extension_fixed_shape_tensor_types,
+)
 from ray.data import Schema
 from ray.data._internal.datasource.parquet_bulk_datasource import ParquetBulkDatasource
 from ray.data._internal.datasource.parquet_datasource import (
@@ -1245,7 +1248,8 @@ def test_tensors_in_tables_parquet(
     )
 
     assert isinstance(
-        ds.schema().base_schema.field_by_name(tensor_col_name).type, ArrowTensorType
+        ds.schema().base_schema.field_by_name(tensor_col_name).type,
+        get_arrow_extension_fixed_shape_tensor_types(),
     )
 
     expected_tuples = list(zip(id_vals, group_vals, arr))
diff --git a/python/ray/data/tests/test_strict_mode.py b/python/ray/data/tests/test_strict_mode.py
index cc96a517a548..49b4b9cc4e37 100644
--- a/python/ray/data/tests/test_strict_mode.py
+++ b/python/ray/data/tests/test_strict_mode.py
@@ -219,12 +219,21 @@ def test_strict_schema(ray_start_regular_shared):
     schema = ds.schema()
     assert isinstance(schema.base_schema, pa.lib.Schema)
     assert schema.names == ["data"]
-    assert schema.types == [ArrowTensorType(shape=(10,), dtype=pa.float64())]
+
+    from ray.air.util.tensor_extensions.arrow import ArrowTensorTypeV2
+    from ray.data import DataContext
+
+    if DataContext.get_current().use_arrow_tensor_v2:
+        expected_arrow_ext_type = ArrowTensorTypeV2(shape=(10,), dtype=pa.float64())
+    else:
+        expected_arrow_ext_type = ArrowTensorType(shape=(10,), dtype=pa.float64())
+
+    assert schema.types == [expected_arrow_ext_type]
 
     schema = ds.map_batches(lambda x: x, batch_format="pandas").schema()
     assert isinstance(schema.base_schema, PandasBlockSchema)
     assert schema.names == ["data"]
-    assert schema.types == [ArrowTensorType(shape=(10,), dtype=pa.float64())]
+    assert schema.types == [expected_arrow_ext_type]
 
 
 def test_use_raw_dicts(ray_start_regular_shared):
diff --git a/python/ray/data/tests/test_transform_pyarrow.py b/python/ray/data/tests/test_transform_pyarrow.py
index ff3d1dbf6610..570bd8f6592b 100644
--- a/python/ray/data/tests/test_transform_pyarrow.py
+++ b/python/ray/data/tests/test_transform_pyarrow.py
@@ -7,6 +7,8 @@
 import pytest
 
 import ray
+from ray.air.util.tensor_extensions.arrow import ArrowTensorTypeV2
+from ray.data import DataContext
 from ray.data._internal.arrow_ops.transform_pyarrow import concat, unify_schemas
 from ray.data.block import BlockAccessor
 from ray.data.extensions import (
@@ -87,17 +89,27 @@ def test_arrow_concat_tensor_extension_uniform():
     t2 = pa.table({"a": ArrowTensorArray.from_numpy(a2)})
     ts = [t1, t2]
     out = concat(ts)
+    # Check length.
     assert len(out) == 6
 
+    # Check schema.
+    if DataContext.get_current().use_arrow_tensor_v2:
+        tensor_type = ArrowTensorTypeV2
+    else:
+        tensor_type = ArrowTensorType
+
     assert out.column_names == ["a"]
-    assert out.schema.types == [ArrowTensorType((2, 2), pa.int64())]
+    assert out.schema.types == [tensor_type((2, 2), pa.int64())]
+
     # Confirm that concatenation is zero-copy (i.e. it didn't trigger chunk
     # consolidation).
     assert out["a"].num_chunks == 2
 
+    # Check content.
     np.testing.assert_array_equal(out["a"].chunk(0).to_numpy(), a1)
     np.testing.assert_array_equal(out["a"].chunk(1).to_numpy(), a2)
 
+    # Check equivalence.
     expected = pa.concat_tables(ts, promote=True)
     assert out == expected