diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py
index c14fb1dd0da..8e800456c20 100644
--- a/distributed/shuffle/_shuffle.py
+++ b/distributed/shuffle/_shuffle.py
@@ -310,20 +310,25 @@ def split_by_worker(
     constructor = df._constructor_sliced
     assert isinstance(constructor, type)
     worker_for = constructor(worker_for)
-    df = df.merge(
+    df["_worker"] = df[[column]].reset_index(drop=True).merge(
         right=worker_for.cat.codes.rename("_worker"),
         left_on=column,
         right_index=True,
-        how="inner",
-    )
+    ).sort_index()["_worker"]
+    # df = df.merge(
+    #     right=worker_for.cat.codes.rename("_worker"),
+    #     left_on=column,
+    #     right_index=True,
+    #     how="inner",
+    # )
     nrows = len(df)
     if not nrows:
         return {}
     # assert len(df) == nrows  # Not true if some outputs aren't wanted
     # FIXME: If we do not preserve the index something is corrupting the
     # bytestream such that it cannot be deserialized anymore
-    t = to_pyarrow_table_dispatch(df, preserve_index=True)
-    t = t.sort_by("_worker")
+    t = to_pyarrow_table_dispatch(df.sort_values("_worker"), preserve_index=True)
+    # t = t.sort_by("_worker")
     codes = np.asarray(t["_worker"])
     t = t.drop(["_worker"])
     del df
@@ -346,6 +351,94 @@ def split_by_worker(
     return out
 
 
+# def split_by_worker(
+#     df: pd.DataFrame,
+#     column: str,
+#     meta: pd.DataFrame,
+#     worker_for: pd.Series,
+# ) -> dict[Any, pa.Table]:
+#     """
+#     Split data into many arrow batches, partitioned by destination worker
+#     """
+#     import numpy as np
+
+#     from dask.dataframe.dispatch import to_pyarrow_table_dispatch
+
+#     # (cudf support) Avoid pd.Series
+#     constructor = df._constructor_sliced
+#     assert isinstance(constructor, type)
+#     worker_for = constructor(worker_for)
+
+#     df["_worker"] = df[[column]].reset_index(drop=True).merge(
+#         right=worker_for.cat.codes.rename("_worker"),
+#         left_on=column,
+#         right_index=True,
+#     ).sort_index()["_worker"]
+#     # df = df.merge(
+#     #     right=worker_for.cat.codes.rename("_worker"),
+#     #     left_on=column,
+#     #     right_index=True,
+#     #     how="inner",
+#     # )
+#     nrows = len(df)
+#     if not nrows:
+#         return {}
+
+#     c = df["_worker"]
+#     k = len(worker_for.cat.categories)
+#     out = {
+#         worker_for.cat.categories[code]: to_pyarrow_table_dispatch(shard, preserve_index=True)
+#         for code, shard in enumerate(
+#             df.scatter_by_map(
+#                 c.astype(np.int32, copy=False),
+#                 map_size=k,
+#                 keep_index=True,
+#             )
+#         )
+#     }
+#     assert sum(map(len, out.values())) == nrows
+#     return out
+
+#     # # (cudf support) Avoid pd.Series
+#     # constructor = df._constructor_sliced
+#     # assert isinstance(constructor, type)
+#     # worker_for = constructor(worker_for)
+#     # df = df.merge(
+#     #     right=worker_for.cat.codes.rename("_worker"),
+#     #     left_on=column,
+#     #     right_index=True,
+#     #     how="inner",
+#     # )
+#     # nrows = len(df)
+#     # if not nrows:
+#     #     return {}
+#     # # assert len(df) == nrows  # Not true if some outputs aren't wanted
+#     # # FIXME: If we do not preserve the index something is corrupting the
+#     # # bytestream such that it cannot be deserialized anymore
+#     # t = to_pyarrow_table_dispatch(df, preserve_index=True)
+#     # t = t.sort_by("_worker")
+#     # codes = np.asarray(t["_worker"])
+#     # t = t.drop(["_worker"])
+#     # del df
+
+#     # splits = np.where(codes[1:] != codes[:-1])[0] + 1
+#     # splits = np.concatenate([[0], splits])
+
+#     # shards = [
+#     #     t.slice(offset=a, length=b - a) for a, b in toolz.sliding_window(2, splits)
+#     # ]
+#     # shards.append(t.slice(offset=splits[-1], length=None))
+
+#     # unique_codes = codes[splits]
+#     # out = {
+#     #     # FIXME https://github.com/pandas-dev/pandas-stubs/issues/43
+#     #     worker_for.cat.categories[code]: shard
+#     #     for code, shard in zip(unique_codes, shards)
+#     # }
+#     # assert sum(map(len, out.values())) == nrows
+#     # return out
+
+
 def split_by_partition(t: pa.Table, column: str) -> dict[int, pa.Table]:
     """
     Split data into many arrow batches, partitioned by final partition
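
Note on the first hunk: the new `_worker` assignment merges the categorical codes positionally and restores row order with `sort_index()`, instead of an inner merge that may reorder (or drop) rows of `df`. A minimal standalone sketch of that technique, using hypothetical toy data and plain pandas; it assumes `df` carries a default RangeIndex, since the final column assignment aligns on the index:

import pandas as pd

# Hypothetical stand-ins for the shuffle inputs: `df` holds the key column and
# a default RangeIndex; `worker_for` maps key -> destination worker (categorical).
df = pd.DataFrame({"key": [2, 0, 1, 0], "x": list("abcd")})
worker_for = pd.Series(pd.Categorical(["w0", "w1", "w0"]), index=[0, 1, 2])

# Merge the integer category codes onto the key column, then sort_index() to
# restore the original row order before assigning the column back onto df.
df["_worker"] = (
    df[["key"]]
    .reset_index(drop=True)
    .merge(worker_for.cat.codes.rename("_worker"), left_on="key", right_index=True)
    .sort_index()["_worker"]
)

print(df["_worker"].tolist())  # [0, 0, 1, 0]

With the default inner join, keys that have no destination come back as missing values after the aligned assignment rather than being dropped from `df`, which is why the `how="inner"` variant is left commented out above.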
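
And a sketch of the run-boundary slicing that the retained sort-then-slice path builds on (the `splits`/`unique_codes` logic visible in the commented-out copy), with toy sorted codes and plain zip standing in for `toolz.sliding_window`:

import numpy as np

codes = np.asarray([0, 0, 1, 1, 1, 3])             # worker codes, already sorted
splits = np.where(codes[1:] != codes[:-1])[0] + 1  # run boundaries -> [2, 5]
splits = np.concatenate([[0], splits])             # prepend first run start -> [0, 2, 5]
ends = np.append(splits[1:], len(codes))           # matching run ends -> [2, 5, 6]

unique_codes = codes[splits]                       # one code per run -> [0, 1, 3]
shards = {int(c): (int(a), int(b)) for c, a, b in zip(unique_codes, splits, ends)}
print(shards)  # {0: (0, 2), 1: (2, 5), 3: (5, 6)}

Each `(offset, end)` pair corresponds to one `t.slice(offset=a, length=b - a)` shard in the real code, with the final run covered by the trailing `t.slice(offset=splits[-1], length=None)`.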