play with split_by_worker
rjzamora committed Dec 12, 2023
1 parent f16c51d commit 166eb51
Showing 1 changed file with 98 additions and 5 deletions.
distributed/shuffle/_shuffle.py — 103 changes: 98 additions & 5 deletions
@@ -310,20 +310,25 @@ def split_by_worker(
     constructor = df._constructor_sliced
     assert isinstance(constructor, type)
     worker_for = constructor(worker_for)
-    df = df.merge(
+    df["_worker"] = df[[column]].reset_index(drop=True).merge(
         right=worker_for.cat.codes.rename("_worker"),
         left_on=column,
         right_index=True,
-        how="inner",
-    )
+    ).sort_index()["_worker"]
+    # df = df.merge(
+    #     right=worker_for.cat.codes.rename("_worker"),
+    #     left_on=column,
+    #     right_index=True,
+    #     how="inner",
+    # )
     nrows = len(df)
     if not nrows:
         return {}
     # assert len(df) == nrows  # Not true if some outputs aren't wanted
     # FIXME: If we do not preserve the index something is corrupting the
     # bytestream such that it cannot be deserialized anymore
-    t = to_pyarrow_table_dispatch(df, preserve_index=True)
-    t = t.sort_by("_worker")
+    t = to_pyarrow_table_dispatch(df.sort_values("_worker"), preserve_index=True)
+    #t = t.sort_by("_worker")
     codes = np.asarray(t["_worker"])
     t = t.drop(["_worker"])
     del df
@@ -346,6 +351,94 @@ def split_by_worker(
     return out


+# def split_by_worker(
+#     df: pd.DataFrame,
+#     column: str,
+#     meta: pd.DataFrame,
+#     worker_for: pd.Series,
+# ) -> dict[Any, pa.Table]:
+#     """
+#     Split data into many arrow batches, partitioned by destination worker
+#     """
+#     import numpy as np
+
+#     from dask.dataframe.dispatch import to_pyarrow_table_dispatch
+
+#     # (cudf support) Avoid pd.Series
+#     constructor = df._constructor_sliced
+#     assert isinstance(constructor, type)
+#     worker_for = constructor(worker_for)
+
+#     df["_worker"] = df[[column]].reset_index(drop=True).merge(
+#         right=worker_for.cat.codes.rename("_worker"),
+#         left_on=column,
+#         right_index=True,
+#     ).sort_index()["_worker"]
+#     # df = df.merge(
+#     #     right=worker_for.cat.codes.rename("_worker"),
+#     #     left_on=column,
+#     #     right_index=True,
+#     #     how="inner",
+#     # )
+#     nrows = len(df)
+#     if not nrows:
+#         return {}
+
+#     c = df["_worker"]
+#     k = len(worker_for.cat.categories)
+#     out = {
+#         worker_for.cat.categories[code] : to_pyarrow_table_dispatch(shard, preserve_index=True)
+#         for code, shard in enumerate(
+#             df.scatter_by_map(
+#                 c.astype(np.int32, copy=False),
+#                 map_size=k,
+#                 keep_index=True,
+#             )
+#         )
+#     }
+#     assert sum(map(len, out.values())) == nrows
+#     return out
+
+#     # # (cudf support) Avoid pd.Series
+#     # constructor = df._constructor_sliced
+#     # assert isinstance(constructor, type)
+#     # worker_for = constructor(worker_for)
+#     # df = df.merge(
+#     #     right=worker_for.cat.codes.rename("_worker"),
+#     #     left_on=column,
+#     #     right_index=True,
+#     #     how="inner",
+#     # )
+#     # nrows = len(df)
+#     # if not nrows:
+#     #     return {}
+#     # # assert len(df) == nrows  # Not true if some outputs aren't wanted
+#     # # FIXME: If we do not preserve the index something is corrupting the
+#     # # bytestream such that it cannot be deserialized anymore
+#     # t = to_pyarrow_table_dispatch(df, preserve_index=True)
+#     # t = t.sort_by("_worker")
+#     # codes = np.asarray(t["_worker"])
+#     # t = t.drop(["_worker"])
+#     # del df
+
+#     # splits = np.where(codes[1:] != codes[:-1])[0] + 1
+#     # splits = np.concatenate([[0], splits])
+
+#     # shards = [
+#     #     t.slice(offset=a, length=b - a) for a, b in toolz.sliding_window(2, splits)
+#     # ]
+#     # shards.append(t.slice(offset=splits[-1], length=None))
+
+#     # unique_codes = codes[splits]
+#     # out = {
+#     #     # FIXME https://github.com/pandas-dev/pandas-stubs/issues/43
+#     #     worker_for.cat.categories[code]: shard
+#     #     for code, shard in zip(unique_codes, shards)
+#     # }
+#     # assert sum(map(len, out.values())) == nrows
+#     # return out
+
+
 def split_by_partition(t: pa.Table, column: str) -> dict[int, pa.Table]:
     """
     Split data into many arrow batches, partitioned by final partition

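For context on what the diff above exercises: each row gets a destination-worker code attached via a merge against `worker_for.cat.codes`, the partition is sorted on that code, converted to an Arrow table, and contiguous runs of equal codes are sliced off as per-worker shards. The sketch below reproduces that idea in plain pandas/pyarrow outside distributed; `toy_split_by_worker`, the example frame, and the worker addresses are illustrative assumptions rather than code from the repository, and it assumes every output partition has a destination worker (the inner merge drops rows otherwise).

import numpy as np
import pandas as pd
import pyarrow as pa


def toy_split_by_worker(df: pd.DataFrame, column: str, worker_for: pd.Series) -> dict:
    """Return {worker address: pyarrow.Table} for one input partition (illustrative only)."""
    # Map each row's output-partition id to an integer destination-worker code.
    worker_for = worker_for.astype("category")
    codes = (
        df[[column]]
        .reset_index(drop=True)
        .merge(
            worker_for.cat.codes.rename("_worker"),
            left_on=column,
            right_index=True,  # join on worker_for's index (output partition id)
        )
        .sort_index()["_worker"]  # restore original row order
    )
    # Assumes no rows were dropped, i.e. every partition id has a worker.
    df = df.assign(_worker=codes.to_numpy()).sort_values("_worker")
    if not len(df):
        return {}

    # Convert once to Arrow, then slice contiguous runs of equal codes
    # into zero-copy per-worker shards.
    t = pa.Table.from_pandas(df, preserve_index=True)
    codes = np.asarray(t["_worker"])
    t = t.drop(["_worker"])
    splits = np.concatenate(
        [[0], np.where(codes[1:] != codes[:-1])[0] + 1, [len(codes)]]
    )
    return {
        worker_for.cat.categories[codes[a]]: t.slice(a, b - a)
        for a, b in zip(splits[:-1], splits[1:])
    }


# Example: 6 rows whose `part` values are routed to two (made-up) workers.
df = pd.DataFrame({"part": [0, 1, 2, 0, 1, 2], "x": range(6)})
worker_for = pd.Series(["tcp://a:1", "tcp://a:1", "tcp://b:2"], index=[0, 1, 2])
shards = toy_split_by_worker(df, "part", worker_for)
for addr, table in shards.items():
    print(addr, table.num_rows)

Sorting the DataFrame with sort_values("_worker") before the Arrow conversion, rather than calling Table.sort_by afterwards, is the variation this commit tries out, presumably so the sort runs in the DataFrame library (pandas or cudf) instead of in Arrow.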