diff --git a/python/ray/data/_internal/arrow_block.py b/python/ray/data/_internal/arrow_block.py
index 2078aea57575..9a446cbba37c 100644
--- a/python/ray/data/_internal/arrow_block.py
+++ b/python/ray/data/_internal/arrow_block.py
@@ -72,31 +72,51 @@ class ArrowRow(TableRow):
     Row of a tabular Dataset backed by a Arrow Table block.
     """

-    def __getitem__(self, key: str) -> Any:
+    def __getitem__(self, key: Union[str, List[str]]) -> Any:
         from ray.data.extensions.tensor_extension import (
             ArrowTensorType,
             ArrowVariableShapedTensorType,
         )

-        schema = self._row.schema
-        if isinstance(
-            schema.field(key).type,
-            (ArrowTensorType, ArrowVariableShapedTensorType),
-        ):
-            # Build a tensor row.
-            return ArrowBlockAccessor._build_tensor_row(self._row, col_name=key)
+        def get_item(keys: List[str]) -> Any:
+            schema = self._row.schema
+            if isinstance(
+                schema.field(keys[0]).type,
+                (ArrowTensorType, ArrowVariableShapedTensorType),
+            ):
+                # Build a tensor row.
+                return tuple(
+                    [
+                        ArrowBlockAccessor._build_tensor_row(self._row, col_name=key)
+                        for key in keys
+                    ]
+                )
+
+            table = self._row.select(keys)
+            if len(table) == 0:
+                return None
+
+            items = [col[0] for col in table.columns]
+            try:
+                # Try to interpret this as a pyarrow.Scalar value.
+                return tuple([item.as_py() for item in items])
+
+            except AttributeError:
+                # Assume that this row is an element of an extension array, and
+                # that it is bypassing pyarrow's scalar model for Arrow < 8.0.0.
+                return items

-        col = self._row[key]
-        if len(col) == 0:
+        is_single_item = isinstance(key, str)
+        keys = [key] if is_single_item else key
+
+        items = get_item(keys)
+
+        if items is None:
             return None
-        item = col[0]
-        try:
-            # Try to interpret this as a pyarrow.Scalar value.
-            return item.as_py()
-        except AttributeError:
-            # Assume that this row is an element of an extension array, and
-            # that it is bypassing pyarrow's scalar model for Arrow < 8.0.0.
-            return item
+        elif is_single_item:
+            return items[0]
+        else:
+            return items

     def __iter__(self) -> Iterator:
         for k in self._row.column_names:
@@ -428,13 +448,15 @@ def sort_and_partition(

         return find_partitions(table, boundaries, sort_key)

-    def combine(self, key: str, aggs: Tuple["AggregateFn"]) -> Block:
+    def combine(self, key: Union[str, List[str]], aggs: Tuple["AggregateFn"]) -> Block:
         """Combine rows with the same key into an accumulator.

         This assumes the block is already sorted by key in ascending order.

         Args:
-            key: The column name of key or None for global aggregation.
+            key: A column name or list of column names.
+                If this is ``None``, place all rows in a single group.
+
             aggs: The aggregations to do.

         Returns:
@@ -443,9 +465,10 @@ def combine(self, key: str, aggs: Tuple["AggregateFn"]) -> Block:
             aggregation. If key is None then the k column is omitted.
         """
-        if key is not None and not isinstance(key, str):
+        if key is not None and not isinstance(key, (str, list)):
             raise ValueError(
-                "key must be a string or None when aggregating on Arrow blocks, but "
+                "key must be a string, list of strings or None when aggregating "
+                "on Arrow blocks, but "
                 f"got: {type(key)}."
             )

@@ -486,7 +509,15 @@ def iter_groups() -> Iterator[Tuple[KeyType, Block]]:

             # Build the row.
             row = {}
             if key is not None:
-                row[key] = group_key
+                if isinstance(key, list):
+                    keys = key
+                    group_keys = group_key
+                else:
+                    keys = [key]
+                    group_keys = [group_key]
+
+                for k, gk in zip(keys, group_keys):
+                    row[k] = gk
             count = collections.defaultdict(int)
             for agg, accumulator in zip(aggs, accumulators):
@@ -521,7 +552,7 @@ def merge_sorted_blocks(
     @staticmethod
     def aggregate_combined_blocks(
         blocks: List[Block],
-        key: str,
+        key: Union[str, List[str]],
         aggs: Tuple["AggregateFn"],
         finalize: bool,
     ) -> Tuple[Block, BlockMetadata]:
@@ -546,8 +577,12 @@ def aggregate_combined_blocks(
         """
         stats = BlockExecStats.builder()
+
+        keys = key if isinstance(key, list) else [key]
         key_fn = (
-            (lambda r: r[r._row.schema.names[0]]) if key is not None else (lambda r: 0)
+            (lambda r: tuple(r[r._row.schema.names[: len(keys)]]))
+            if key is not None
+            else (lambda r: (0,))
         )

         iter = heapq.merge(
@@ -563,15 +598,15 @@ def aggregate_combined_blocks(
             try:
                 if next_row is None:
                     next_row = next(iter)
-                next_key = key_fn(next_row)
-                next_key_name = (
-                    next_row._row.schema.names[0] if key is not None else None
+                next_keys = key_fn(next_row)
+                next_key_names = (
+                    next_row._row.schema.names[: len(keys)] if key is not None else None
                 )

                 def gen():
                     nonlocal iter
                     nonlocal next_row
-                    while key_fn(next_row) == next_key:
+                    while key_fn(next_row) == next_keys:
                         yield next_row
                         try:
                             next_row = next(iter)
@@ -606,7 +641,8 @@ def gen():
                 # Build the row.
                 row = {}
                 if key is not None:
-                    row[next_key_name] = next_key
+                    for next_key, next_key_name in zip(next_keys, next_key_names):
+                        row[next_key_name] = next_key
                 for agg, agg_name, accumulator in zip(
                     aggs, resolved_agg_names, accumulators
diff --git a/python/ray/data/_internal/pandas_block.py b/python/ray/data/_internal/pandas_block.py
index 87b1f323e148..54d1d8b409f7 100644
--- a/python/ray/data/_internal/pandas_block.py
+++ b/python/ray/data/_internal/pandas_block.py
@@ -55,24 +55,40 @@ class PandasRow(TableRow):
     Row of a tabular Dataset backed by a Pandas DataFrame block.
     """

-    def __getitem__(self, key: str) -> Any:
+    def __getitem__(self, key: Union[str, List[str]]) -> Any:
         from ray.data.extensions import TensorArrayElement

-        col = self._row[key]
-        if len(col) == 0:
+        def get_item(keys: List[str]) -> Any:
+            col = self._row[keys]
+            if len(col) == 0:
+                return None
+
+            items = col.iloc[0]
+            if isinstance(items[0], TensorArrayElement):
+                # Getting an item in a Pandas tensor column may return
+                # a TensorArrayElement, which we have to convert to an ndarray.
+                return tuple([item.to_numpy() for item in items])
+
+            try:
+                # Try to interpret this as a numpy-type value.
+                # See https://stackoverflow.com/questions/9452775/converting-numpy-dtypes-to-native-python-types.  # noqa: E501
+                return tuple([item.item() for item in items])
+
+            except (AttributeError, ValueError):
+                # Fallback to the original form.
+                return items
+
+        is_single_item = isinstance(key, str)
+        keys = [key] if is_single_item else key
+
+        items = get_item(keys)
+
+        if items is None:
             return None
-        item = col.iloc[0]
-        if isinstance(item, TensorArrayElement):
-            # Getting an item in a Pandas tensor column may return a TensorArrayElement,
-            # which we have to convert to an ndarray.
-            item = item.to_numpy()
-        try:
-            # Try to interpret this as a numpy-type value.
-            # See https://stackoverflow.com/questions/9452775/converting-numpy-dtypes-to-native-python-types.  # noqa: E501
-            return item.item()
-        except (AttributeError, ValueError):
-            # Fallback to the original form.
-            return item
+        elif is_single_item:
+            return items[0]
+        else:
+            return items

     def __iter__(self) -> Iterator:
         for k in self._row.columns:
@@ -359,13 +375,17 @@ def sort_and_partition(

         return find_partitions(table, boundaries, sort_key)

-    def combine(self, key: str, aggs: Tuple["AggregateFn"]) -> "pandas.DataFrame":
+    def combine(
+        self, key: Union[str, List[str]], aggs: Tuple["AggregateFn"]
+    ) -> "pandas.DataFrame":
         """Combine rows with the same key into an accumulator.

         This assumes the block is already sorted by key in ascending order.

         Args:
-            key: The column name of key or None for global aggregation.
+            key: A column name or list of column names.
+                If this is ``None``, place all rows in a single group.
+
             aggs: The aggregations to do.

         Returns:
@@ -374,9 +394,10 @@ def combine(self, key: str, aggs: Tuple["AggregateFn"]) -> "pandas.DataFrame":
             aggregation. If key is None then the k column is omitted.
         """
-        if key is not None and not isinstance(key, str):
+        if key is not None and not isinstance(key, (str, list)):
             raise ValueError(
-                "key must be a string or None when aggregating on Pandas blocks, but "
+                "key must be a string, list of strings or None when aggregating "
+                "on Pandas blocks, but "
                 f"got: {type(key)}."
             )

@@ -395,7 +416,7 @@ def iter_groups() -> Iterator[Tuple[KeyType, Block]]:
                     if next_row is None:
                         next_row = next(iter)
                     next_key = next_row[key]
-                    while next_row[key] == next_key:
+                    while np.all(next_row[key] == next_key):
                         end += 1
                         try:
                             next_row = next(iter)
@@ -417,7 +438,15 @@ def iter_groups() -> Iterator[Tuple[KeyType, Block]]:

             # Build the row.
             row = {}
             if key is not None:
-                row[key] = group_key
+                if isinstance(key, list):
+                    keys = key
+                    group_keys = group_key
+                else:
+                    keys = [key]
+                    group_keys = [group_key]
+
+                for k, gk in zip(keys, group_keys):
+                    row[k] = gk
             count = collections.defaultdict(int)
             for agg, accumulator in zip(aggs, accumulators):
@@ -452,7 +481,7 @@ def merge_sorted_blocks(
     @staticmethod
     def aggregate_combined_blocks(
         blocks: List["pandas.DataFrame"],
-        key: str,
+        key: Union[str, List[str]],
         aggs: Tuple["AggregateFn"],
         finalize: bool,
     ) -> Tuple["pandas.DataFrame", BlockMetadata]:
@@ -477,7 +506,12 @@ def aggregate_combined_blocks(
         """
         stats = BlockExecStats.builder()

-        key_fn = (lambda r: r[r._row.columns[0]]) if key is not None else (lambda r: 0)
+        keys = key if isinstance(key, list) else [key]
+        key_fn = (
+            (lambda r: tuple(r[r._row.columns[: len(keys)]]))
+            if key is not None
+            else (lambda r: (0,))
+        )

         iter = heapq.merge(
             *[
@@ -492,13 +526,15 @@ def aggregate_combined_blocks(
             try:
                 if next_row is None:
                     next_row = next(iter)
-                next_key = key_fn(next_row)
-                next_key_name = next_row._row.columns[0] if key is not None else None
+                next_keys = key_fn(next_row)
+                next_key_names = (
+                    next_row._row.columns[: len(keys)] if key is not None else None
+                )

                 def gen():
                     nonlocal iter
                     nonlocal next_row
-                    while key_fn(next_row) == next_key:
+                    while key_fn(next_row) == next_keys:
                         yield next_row
                         try:
                             next_row = next(iter)
@@ -533,7 +569,8 @@ def gen():
                 # Build the row.
                 row = {}
                 if key is not None:
-                    row[next_key_name] = next_key
+                    for next_key, next_key_name in zip(next_keys, next_key_names):
+                        row[next_key_name] = next_key
                 for agg, agg_name, accumulator in zip(
                     aggs, resolved_agg_names, accumulators
diff --git a/python/ray/data/_internal/planner/exchange/aggregate_task_spec.py b/python/ray/data/_internal/planner/exchange/aggregate_task_spec.py
index 16ea77b49975..a3a2293d77d6 100644
--- a/python/ray/data/_internal/planner/exchange/aggregate_task_spec.py
+++ b/python/ray/data/_internal/planner/exchange/aggregate_task_spec.py
@@ -41,7 +41,7 @@ def map(
         block: Block,
         output_num_blocks: int,
         boundaries: List[KeyType],
-        key: Optional[str],
+        key: Union[str, List[str], None],
         aggs: List[AggregateFn],
     ) -> List[Union[BlockMetadata, Block]]:
         stats = BlockExecStats.builder()
@@ -74,7 +74,7 @@ def reduce(
     @staticmethod
     def _prune_unused_columns(
         block: Block,
-        key: str,
+        key: Union[str, List[str]],
         aggs: Tuple[AggregateFn],
     ) -> Block:
         """Prune unused columns from block before aggregate."""
@@ -83,6 +83,8 @@ def _prune_unused_columns(

         if isinstance(key, str):
             columns.add(key)
+        elif isinstance(key, list):
+            columns.update(key)
         elif callable(key):
             prune_columns = False
diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index 5dbc68b38e2e..f95e56f9eb83 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -1833,7 +1833,7 @@ def union(self, *other: List["Dataset"]) -> "Dataset":
             logical_plan,
         )

-    def groupby(self, key: Optional[str]) -> "GroupedData":
+    def groupby(self, key: Union[str, List[str], None]) -> "GroupedData":
         """Group rows of a :class:`Dataset` according to a column.

         Use this method to transform data based on a
@@ -1860,7 +1860,8 @@ def normalize_variety(group: pd.DataFrame) -> pd.DataFrame:
         Time complexity: O(dataset size * log(dataset size / parallelism))

         Args:
-            key: A column name. If this is ``None``, place all rows in a single group.
+            key: A column name or list of column names.
+                If this is ``None``, place all rows in a single group.

         Returns:
             A lazy :class:`~ray.data.grouped_data.GroupedData`.
diff --git a/python/ray/data/grouped_data.py b/python/ray/data/grouped_data.py
index 1eddf7e8d9c1..aec2192fe0f7 100644
--- a/python/ray/data/grouped_data.py
+++ b/python/ray/data/grouped_data.py
@@ -114,7 +114,7 @@ class GroupedData:
     The actual groupby is deferred until an aggregation is applied.
     """

-    def __init__(self, dataset: Dataset, key: str):
+    def __init__(self, dataset: Dataset, key: Union[str, List[str]]):
         """Construct a dataset grouped by key (internal API).

         The constructor is not part of the GroupedData API.
diff --git a/python/ray/data/tests/test_all_to_all.py b/python/ray/data/tests/test_all_to_all.py
index 89512c1bc45f..424fe63be765 100644
--- a/python/ray/data/tests/test_all_to_all.py
+++ b/python/ray/data/tests/test_all_to_all.py
@@ -16,6 +16,8 @@
 from ray.data.tests.util import column_udf, named_values
 from ray.tests.conftest import *  # noqa

+RANDOM_SEED = 123
+

 def test_zip(ray_start_regular_shared):
     ds1 = ray.data.range(5, parallelism=5)
@@ -331,6 +333,34 @@ def _to_pandas(ds):
     ]


+@pytest.mark.parametrize("num_parts", [1, 30])
+@pytest.mark.parametrize("ds_format", ["pyarrow", "pandas"])
+def test_groupby_multiple_keys_tabular_count(
+    ray_start_regular_shared, ds_format, num_parts, use_push_based_shuffle
+):
+    # Test built-in count aggregation
+    print(f"Seeding RNG for test_groupby_multiple_keys_tabular_count: {RANDOM_SEED}")
+    random.seed(RANDOM_SEED)
+    xs = list(range(100))
+    random.shuffle(xs)
+
+    ds = ray.data.from_items([{"A": (x % 2), "B": (x % 3)} for x in xs]).repartition(
+        num_parts
+    )
+    ds = ds.map_batches(lambda x: x, batch_size=None, batch_format=ds_format)
+
+    agg_ds = ds.groupby(["A", "B"]).count()
+    assert agg_ds.count() == 6
+    assert list(agg_ds.sort(["A", "B"]).iter_rows()) == [
+        {"A": 0, "B": 0, "count()": 17},
+        {"A": 0, "B": 1, "count()": 16},
+        {"A": 0, "B": 2, "count()": 17},
+        {"A": 1, "B": 0, "count()": 17},
+        {"A": 1, "B": 1, "count()": 17},
+        {"A": 1, "B": 2, "count()": 16},
+    ]
+
+
 @pytest.mark.parametrize("num_parts", [1, 30])
 @pytest.mark.parametrize("ds_format", ["arrow", "pandas"])
 def test_groupby_tabular_sum(
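
With this patch, `TableRow.__getitem__` (both the `ArrowRow` and `PandasRow` implementations) accepts either a single column name or a list of names: a single name still returns an unwrapped scalar, while a list returns a tuple with one entry per requested column. A minimal sketch of the new contract, using a hypothetical one-row block with columns `A` and `B` (the values are illustrative, not taken from the diff):

```python
# row is a TableRow over a one-row block with A=0 and B=2 (hypothetical values).
row["A"]         # -> 0       single key: unwrapped scalar, as before
row[["A", "B"]]  # -> (0, 2)  list of keys: one tuple entry per column
row[["A"]]       # -> (0,)    a one-element list still yields a tuple
```

This tuple form is what lets `aggregate_combined_blocks` build its merge key as `lambda r: tuple(r[...names[: len(keys)]])` and compare whole key tuples at once when grouping sorted rows.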
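
End to end, the user-facing change is that `Dataset.groupby` now also accepts a list of column names. A minimal usage sketch distilled from the new test above, assuming a running local Ray session (`Dataset.sort` already accepts a list of keys):

```python
import ray

# 100 rows; column A cycles through {0, 1} and B through {0, 1, 2}.
ds = ray.data.from_items([{"A": x % 2, "B": x % 3} for x in range(100)])

# Group on both columns: each distinct (A, B) pair forms one group, 6 in total.
agg_ds = ds.groupby(["A", "B"]).count()
assert agg_ds.count() == 6

# Each output row carries both key columns plus the aggregate column.
for row in agg_ds.sort(["A", "B"]).iter_rows():
    print(row)  # e.g. {'A': 0, 'B': 0, 'count()': 17}
```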