[Data] Enable group over multiple keys in datasets (ray-project#37832)
This is a first draft toward enabling grouping of Ray Datasets over multiple keys (i.e., passing a list of keys to `.groupby()`).

A new unit test, `test_groupby_multiple_keys_tabular_count`, showcases an example; a minimal usage sketch follows below.
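As a quick illustration, here is a hedged usage sketch of the new multi-key `groupby` (the dataset and column names are hypothetical, not taken from the test itself):

```python
import ray

# Hypothetical tabular dataset with two candidate grouping columns.
ds = ray.data.from_items(
    [
        {"color": "red", "size": "S", "qty": 1},
        {"color": "red", "size": "M", "qty": 2},
        {"color": "blue", "size": "S", "qty": 3},
        {"color": "red", "size": "S", "qty": 4},
    ]
)

# Pass a list of column names to group over multiple keys.
counts = ds.groupby(["color", "size"]).count()
print(counts.take_all())
# One row per distinct (color, size) pair, e.g.:
# {'color': 'blue', 'size': 'S', 'count()': 1}, ...
```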

---------

Signed-off-by: Abdel Jaidi <[email protected]>
Signed-off-by: Anton Kukushkin <[email protected]>
Co-authored-by: Abdel Jaidi <[email protected]>
Co-authored-by: Anton Kukushkin <[email protected]>
Co-authored-by: Hao Chen <[email protected]>
4 people authored Oct 27, 2023
1 parent 187c5c5 commit 310409f
Showing 6 changed files with 168 additions and 62 deletions.
96 changes: 66 additions & 30 deletions python/ray/data/_internal/arrow_block.py
@@ -72,31 +72,51 @@ class ArrowRow(TableRow):
     Row of a tabular Dataset backed by a Arrow Table block.
     """
 
-    def __getitem__(self, key: str) -> Any:
+    def __getitem__(self, key: Union[str, List[str]]) -> Any:
         from ray.data.extensions.tensor_extension import (
             ArrowTensorType,
             ArrowVariableShapedTensorType,
         )
 
-        schema = self._row.schema
-        if isinstance(
-            schema.field(key).type,
-            (ArrowTensorType, ArrowVariableShapedTensorType),
-        ):
-            # Build a tensor row.
-            return ArrowBlockAccessor._build_tensor_row(self._row, col_name=key)
-
-        col = self._row[key]
-        if len(col) == 0:
-            return None
-        item = col[0]
-        try:
-            # Try to interpret this as a pyarrow.Scalar value.
-            return item.as_py()
-        except AttributeError:
-            # Assume that this row is an element of an extension array, and
-            # that it is bypassing pyarrow's scalar model for Arrow < 8.0.0.
-            return item
+        def get_item(keys: List[str]) -> Any:
+            schema = self._row.schema
+            if isinstance(
+                schema.field(keys[0]).type,
+                (ArrowTensorType, ArrowVariableShapedTensorType),
+            ):
+                # Build a tensor row.
+                return tuple(
+                    [
+                        ArrowBlockAccessor._build_tensor_row(self._row, col_name=key)
+                        for key in keys
+                    ]
+                )
+
+            table = self._row.select(keys)
+            if len(table) == 0:
+                return None
+
+            items = [col[0] for col in table.columns]
+            try:
+                # Try to interpret this as a pyarrow.Scalar value.
+                return tuple([item.as_py() for item in items])
+            except AttributeError:
+                # Assume that this row is an element of an extension array, and
+                # that it is bypassing pyarrow's scalar model for Arrow < 8.0.0.
+                return items
+
+        is_single_item = isinstance(key, str)
+        keys = [key] if is_single_item else key
+
+        items = get_item(keys)
+
+        if items is None:
+            return None
+        elif is_single_item:
+            return items[0]
+        else:
+            return items
 
     def __iter__(self) -> Iterator:
         for k in self._row.column_names:
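The net effect: a string key keeps the old scalar behavior, while a list of keys returns a tuple. A hedged sketch using internal accessors (illustration only; `BlockAccessor.for_block`, `iter_rows(public_row_format=False)`, and the table contents are assumptions, not part of this diff):

```python
import pyarrow as pa

from ray.data.block import BlockAccessor

# Hypothetical one-row Arrow block, like those produced during groupby.
block = pa.table({"color": ["red"], "size": ["S"], "qty": [1]})
row = next(BlockAccessor.for_block(block).iter_rows(public_row_format=False))

print(row["color"])            # "red"        (str key: scalar, as before)
print(row[["color", "size"]])  # ("red", "S") (list of keys: tuple)
```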
@@ -428,13 +448,15 @@ def sort_and_partition(
 
         return find_partitions(table, boundaries, sort_key)
 
-    def combine(self, key: str, aggs: Tuple["AggregateFn"]) -> Block:
+    def combine(self, key: Union[str, List[str]], aggs: Tuple["AggregateFn"]) -> Block:
         """Combine rows with the same key into an accumulator.
 
         This assumes the block is already sorted by key in ascending order.
 
         Args:
-            key: The column name of key or None for global aggregation.
+            key: A column name or list of column names.
+                If this is ``None``, place all rows in a single group.
             aggs: The aggregations to do.
 
         Returns:
@@ -443,9 +465,10 @@ def combine(self, key: str, aggs: Tuple["AggregateFn"]) -> Block:
             aggregation.
             If key is None then the k column is omitted.
         """
-        if key is not None and not isinstance(key, str):
+        if key is not None and not isinstance(key, (str, list)):
             raise ValueError(
-                "key must be a string or None when aggregating on Arrow blocks, but "
+                "key must be a string, list of strings or None when aggregating "
+                "on Arrow blocks, but "
                 f"got: {type(key)}."
             )
 
@@ -486,7 +509,15 @@ def iter_groups() -> Iterator[Tuple[KeyType, Block]]:
             # Build the row.
             row = {}
             if key is not None:
-                row[key] = group_key
+                if isinstance(key, list):
+                    keys = key
+                    group_keys = group_key
+                else:
+                    keys = [key]
+                    group_keys = [group_key]
+
+                for k, gk in zip(keys, group_keys):
+                    row[k] = gk
 
             count = collections.defaultdict(int)
             for agg, accumulator in zip(aggs, accumulators):
@@ -521,7 +552,7 @@ def merge_sorted_blocks(
     @staticmethod
     def aggregate_combined_blocks(
         blocks: List[Block],
-        key: str,
+        key: Union[str, List[str]],
         aggs: Tuple["AggregateFn"],
         finalize: bool,
     ) -> Tuple[Block, BlockMetadata]:
@@ -546,8 +577,12 @@ def aggregate_combined_blocks(
         """
 
         stats = BlockExecStats.builder()
+
+        keys = key if isinstance(key, list) else [key]
         key_fn = (
-            (lambda r: r[r._row.schema.names[0]]) if key is not None else (lambda r: 0)
+            (lambda r: tuple(r[r._row.schema.names[: len(keys)]]))
+            if key is not None
+            else (lambda r: (0,))
         )
 
         iter = heapq.merge(
@@ -563,15 +598,15 @@
         try:
             if next_row is None:
                 next_row = next(iter)
-            next_key = key_fn(next_row)
-            next_key_name = (
-                next_row._row.schema.names[0] if key is not None else None
+            next_keys = key_fn(next_row)
+            next_key_names = (
+                next_row._row.schema.names[: len(keys)] if key is not None else None
             )
 
             def gen():
                 nonlocal iter
                 nonlocal next_row
-                while key_fn(next_row) == next_key:
+                while key_fn(next_row) == next_keys:
                     yield next_row
                     try:
                         next_row = next(iter)
@@ -606,7 +641,8 @@ def gen():
             # Build the row.
             row = {}
             if key is not None:
-                row[next_key_name] = next_key
+                for next_key, next_key_name in zip(next_keys, next_key_names):
+                    row[next_key_name] = next_key
 
             for agg, agg_name, accumulator in zip(
                 aggs, resolved_agg_names, accumulators
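The merge step above now compares composite keys as tuples. The underlying pattern, in isolation (a plain-stdlib sketch with hypothetical rows, not the Ray internals):

```python
import heapq
from itertools import groupby

# Two row streams, each already sorted by the composite (color, size) key.
block_a = [("blue", "S", 3), ("red", "M", 2)]
block_b = [("blue", "S", 5), ("red", "S", 1)]

def key_fn(row):
    return row[:2]  # the composite key as a tuple

# heapq.merge keeps the combined stream sorted; groupby then yields one
# group per distinct key tuple, mirroring what gen() does above.
merged = heapq.merge(block_a, block_b, key=key_fn)
for group_key, rows in groupby(merged, key=key_fn):
    print(group_key, sum(r[2] for r in rows))
# ('blue', 'S') 8
# ('red', 'M') 2
# ('red', 'S') 1
```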
91 changes: 64 additions & 27 deletions python/ray/data/_internal/pandas_block.py
@@ -55,24 +55,40 @@ class PandasRow(TableRow):
     Row of a tabular Dataset backed by a Pandas DataFrame block.
     """
 
-    def __getitem__(self, key: str) -> Any:
+    def __getitem__(self, key: Union[str, List[str]]) -> Any:
         from ray.data.extensions import TensorArrayElement
 
-        col = self._row[key]
-        if len(col) == 0:
-            return None
-        item = col.iloc[0]
-        if isinstance(item, TensorArrayElement):
-            # Getting an item in a Pandas tensor column may return a TensorArrayElement,
-            # which we have to convert to an ndarray.
-            item = item.to_numpy()
-        try:
-            # Try to interpret this as a numpy-type value.
-            # See https://stackoverflow.com/questions/9452775/converting-numpy-dtypes-to-native-python-types.  # noqa: E501
-            return item.item()
-        except (AttributeError, ValueError):
-            # Fallback to the original form.
-            return item
+        def get_item(keys: List[str]) -> Any:
+            col = self._row[keys]
+            if len(col) == 0:
+                return None
+
+            items = col.iloc[0]
+            if isinstance(items[0], TensorArrayElement):
+                # Getting an item in a Pandas tensor column may return
+                # a TensorArrayElement, which we have to convert to an ndarray.
+                return tuple([item.to_numpy() for item in items])
+
+            try:
+                # Try to interpret this as a numpy-type value.
+                # See https://stackoverflow.com/questions/9452775/converting-numpy-dtypes-to-native-python-types.  # noqa: E501
+                return tuple([item.as_py() for item in items])
+
+            except (AttributeError, ValueError):
+                # Fallback to the original form.
+                return items
+
+        is_single_item = isinstance(key, str)
+        keys = [key] if is_single_item else key
+
+        items = get_item(keys)
+
+        if items is None:
+            return None
+        elif is_single_item:
+            return items[0]
+        else:
+            return items
 
     def __iter__(self) -> Iterator:
         for k in self._row.columns:
@@ -359,13 +375,17 @@ def sort_and_partition(
 
         return find_partitions(table, boundaries, sort_key)
 
-    def combine(self, key: str, aggs: Tuple["AggregateFn"]) -> "pandas.DataFrame":
+    def combine(
+        self, key: Union[str, List[str]], aggs: Tuple["AggregateFn"]
+    ) -> "pandas.DataFrame":
         """Combine rows with the same key into an accumulator.
 
         This assumes the block is already sorted by key in ascending order.
 
         Args:
-            key: The column name of key or None for global aggregation.
+            key: A column name or list of column names.
+                If this is ``None``, place all rows in a single group.
             aggs: The aggregations to do.
 
         Returns:
@@ -374,9 +394,10 @@ def combine(self, key: str, aggs: Tuple["AggregateFn"]) -> "pandas.DataFrame":
             aggregation.
             If key is None then the k column is omitted.
         """
-        if key is not None and not isinstance(key, str):
+        if key is not None and not isinstance(key, (str, list)):
             raise ValueError(
-                "key must be a string or None when aggregating on Pandas blocks, but "
+                "key must be a string, list of strings or None when aggregating "
+                "on Pandas blocks, but "
                 f"got: {type(key)}."
             )

@@ -395,7 +416,7 @@ def iter_groups() -> Iterator[Tuple[KeyType, Block]]:
                 try:
                     if next_row is None:
                         next_row = next(iter)
                     next_key = next_row[key]
-                    while next_row[key] == next_key:
+                    while np.all(next_row[key] == next_key):
                         end += 1
                         try:
                             next_row = next(iter)
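With a list-valued `key`, `next_row[key]` is multi-element, so a plain `==` produces an element-wise result rather than a single boolean; `np.all` collapses it. A hedged standalone sketch of the comparison semantics (hypothetical values, not the Ray row types):

```python
import numpy as np
import pandas as pd

frame = pd.DataFrame({"color": ["red"], "size": ["S"]})

single = frame["color"].iloc[0] == "red"             # plain bool
multi = frame[["color", "size"]].iloc[0] == pd.Series(
    {"color": "red", "size": "S"}
)                                                    # element-wise bools
print(single, bool(np.all(multi)))                   # True True
```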
@@ -417,7 +438,15 @@ def iter_groups() -> Iterator[Tuple[KeyType, Block]]:
             # Build the row.
             row = {}
             if key is not None:
-                row[key] = group_key
+                if isinstance(key, list):
+                    keys = key
+                    group_keys = group_key
+                else:
+                    keys = [key]
+                    group_keys = [group_key]
+
+                for k, gk in zip(keys, group_keys):
+                    row[k] = gk
 
             count = collections.defaultdict(int)
             for agg, accumulator in zip(aggs, accumulators):
@@ -452,7 +481,7 @@ def merge_sorted_blocks(
     @staticmethod
     def aggregate_combined_blocks(
         blocks: List["pandas.DataFrame"],
-        key: str,
+        key: Union[str, List[str]],
         aggs: Tuple["AggregateFn"],
         finalize: bool,
     ) -> Tuple["pandas.DataFrame", BlockMetadata]:
@@ -477,7 +506,12 @@ def aggregate_combined_blocks(
         """
 
         stats = BlockExecStats.builder()
-        key_fn = (lambda r: r[r._row.columns[0]]) if key is not None else (lambda r: 0)
+        keys = key if isinstance(key, list) else [key]
+        key_fn = (
+            (lambda r: tuple(r[r._row.columns[: len(keys)]]))
+            if key is not None
+            else (lambda r: (0,))
+        )
 
         iter = heapq.merge(
             *[
@@ -492,13 +526,15 @@
         try:
             if next_row is None:
                 next_row = next(iter)
-            next_key = key_fn(next_row)
-            next_key_name = next_row._row.columns[0] if key is not None else None
+            next_keys = key_fn(next_row)
+            next_key_names = (
+                next_row._row.columns[: len(keys)] if key is not None else None
+            )
 
             def gen():
                 nonlocal iter
                 nonlocal next_row
-                while key_fn(next_row) == next_key:
+                while key_fn(next_row) == next_keys:
                     yield next_row
                     try:
                         next_row = next(iter)
@@ -533,7 +569,8 @@ def gen():
             # Build the row.
             row = {}
             if key is not None:
-                row[next_key_name] = next_key
+                for next_key, next_key_name in zip(next_keys, next_key_names):
+                    row[next_key_name] = next_key
 
             for agg, agg_name, accumulator in zip(
                 aggs, resolved_agg_names, accumulators
@@ -41,7 +41,7 @@ def map(
         block: Block,
         output_num_blocks: int,
         boundaries: List[KeyType],
-        key: Optional[str],
+        key: Union[str, List[str], None],
         aggs: List[AggregateFn],
     ) -> List[Union[BlockMetadata, Block]]:
         stats = BlockExecStats.builder()
@@ -74,7 +74,7 @@ def reduce(
     @staticmethod
     def _prune_unused_columns(
         block: Block,
-        key: str,
+        key: Union[str, List[str]],
         aggs: Tuple[AggregateFn],
     ) -> Block:
         """Prune unused columns from block before aggregate."""
@@ -83,6 +83,8 @@ def _prune_unused_columns(
 
         if isinstance(key, str):
             columns.add(key)
+        elif isinstance(key, list):
+            columns.update(key)
         elif callable(key):
             prune_columns = False
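Column pruning now also keeps every column named in a list-valued key. A hedged standalone restatement of that decision logic (a hypothetical helper, not the Ray function itself):

```python
from typing import List, Optional, Set, Union

def columns_to_keep(
    key: Union[str, List[str], None], agg_columns: Set[str]
) -> Optional[Set[str]]:
    # Hypothetical mirror of _prune_unused_columns' column selection.
    columns = set(agg_columns)
    if isinstance(key, str):
        columns.add(key)
    elif isinstance(key, list):
        columns.update(key)  # keep every grouping column
    elif callable(key):
        return None  # a key function may read any column; skip pruning
    return columns

print(columns_to_keep(["color", "size"], {"qty"}))  # {'color', 'size', 'qty'}
```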
5 changes: 3 additions & 2 deletions python/ray/data/dataset.py
@@ -1833,7 +1833,7 @@ def union(self, *other: List["Dataset"]) -> "Dataset":
             logical_plan,
         )
 
-    def groupby(self, key: Optional[str]) -> "GroupedData":
+    def groupby(self, key: Union[str, List[str], None]) -> "GroupedData":
         """Group rows of a :class:`Dataset` according to a column.
 
         Use this method to transform data based on a
@@ -1860,7 +1860,8 @@ def normalize_variety(group: pd.DataFrame) -> pd.DataFrame:
         Time complexity: O(dataset size * log(dataset size / parallelism))
 
         Args:
-            key: A column name. If this is ``None``, place all rows in a single group.
+            key: A column name or list of column names.
+                If this is ``None``, place all rows in a single group.
 
         Returns:
             A lazy :class:`~ray.data.grouped_data.GroupedData`.
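With the widened `key` type, the aggregation shortcuts on the resulting `GroupedData` accept the same list form. A hedged sketch with hypothetical columns:

```python
import ray

ds = ray.data.from_items(
    [
        {"store": "A", "item": "pen", "price": 1.0},
        {"store": "A", "item": "pen", "price": 3.0},
        {"store": "B", "item": "ink", "price": 5.0},
    ]
)

# One output row per distinct (store, item) pair.
print(ds.groupby(["store", "item"]).sum("price").take_all())
# e.g. [{'store': 'A', 'item': 'pen', 'sum(price)': 4.0},
#       {'store': 'B', 'item': 'ink', 'sum(price)': 5.0}]
```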
2 changes: 1 addition & 1 deletion python/ray/data/grouped_data.py
@@ -114,7 +114,7 @@ class GroupedData:
     The actual groupby is deferred until an aggregation is applied.
     """
 
-    def __init__(self, dataset: Dataset, key: str):
+    def __init__(self, dataset: Dataset, key: Union[str, List[str]]):
         """Construct a dataset grouped by key (internal API).
 
         The constructor is not part of the GroupedData API.