Merge branch 'branch-23.10' into plugin-inherit
pentschev authored Sep 28, 2023
2 parents c8e0982 + f98963d commit af9c522
Showing 13 changed files with 217 additions and 61 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/build.yaml
@@ -43,7 +43,7 @@ jobs:
arch: "amd64"
branch: ${{ inputs.branch }}
build_type: ${{ inputs.build_type || 'branch' }}
container_image: "rapidsai/ci:latest"
container_image: "rapidsai/ci-conda:latest"
date: ${{ inputs.date }}
node_type: "gpu-v100-latest-1"
run_script: "ci/build_docs.sh"
@@ -60,7 +60,7 @@ jobs:
wheel-build:
runs-on: ubuntu-latest
container:
- image: rapidsai/ci:latest
+ image: rapidsai/ci-conda:latest
defaults:
run:
shell: bash
4 changes: 2 additions & 2 deletions .github/workflows/pr.yaml
@@ -42,13 +42,13 @@ jobs:
build_type: pull-request
node_type: "gpu-v100-latest-1"
arch: "amd64"
container_image: "rapidsai/ci:latest"
container_image: "rapidsai/ci-conda:latest"
run_script: "ci/build_docs.sh"
wheel-build:
needs: checks
runs-on: ubuntu-latest
container:
- image: rapidsai/ci:latest
+ image: rapidsai/ci-conda:latest
defaults:
run:
shell: bash
2 changes: 1 addition & 1 deletion ci/build_python.sh
@@ -11,7 +11,7 @@ rapids-print-env

rapids-logger "Begin py build"

- rapids-mamba-retry mambabuild \
+ rapids-conda-retry mambabuild \
conda/recipes/dask-cuda

rapids-upload-conda-to-s3 python
2 changes: 1 addition & 1 deletion conda/recipes/dask-cuda/meta.yaml
@@ -32,7 +32,7 @@ requirements:
- tomli
run:
- python
- - dask-core >=2023.7.1
+ - dask-core ==2023.9.2
{% for r in data.get("project", {}).get("dependencies", []) %}
- {{ r }}
{% endfor %}
1 change: 1 addition & 0 deletions dask_cuda/__init__.py
@@ -21,6 +21,7 @@

__version__ = "23.10.00"

+ from . import compat

# Monkey patching Dask to make use of explicit-comms when `DASK_EXPLICIT_COMMS=True`
dask.dataframe.shuffle.rearrange_by_column = get_rearrange_by_column_wrapper(
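The hunk above is where dask-cuda installs its explicit-comms monkey patch at import time. As a rough illustrative sketch only (the real wrapper's body is not shown in this diff), a config-gated wrapper of `rearrange_by_column` could look like the following; the dispatch branch is left as a comment because the actual explicit-comms shuffle lives elsewhere in dask-cuda:

import dask
import dask.dataframe.shuffle


def get_rearrange_by_column_wrapper(func):
    """Illustrative sketch: gate ``rearrange_by_column`` behind a config flag."""

    def wrapper(*args, **kwargs):
        # DASK_EXPLICIT_COMMS=True surfaces in the Dask config as "explicit-comms".
        if dask.config.get("explicit-comms", False):
            # dask-cuda would dispatch to its explicit-comms shuffle here;
            # this sketch simply falls through to the original implementation.
            pass
        return func(*args, **kwargs)

    return wrapper


# Patching in the same style as dask_cuda/__init__.py
dask.dataframe.shuffle.rearrange_by_column = get_rearrange_by_column_wrapper(
    dask.dataframe.shuffle.rearrange_by_column
)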
118 changes: 118 additions & 0 deletions dask_cuda/compat.py
@@ -0,0 +1,118 @@
import pickle

import msgpack
from packaging.version import Version

import dask
import distributed
import distributed.comm.utils
import distributed.protocol
from distributed.comm.utils import OFFLOAD_THRESHOLD, nbytes, offload
from distributed.protocol.core import (
Serialized,
decompress,
logger,
merge_and_deserialize,
msgpack_decode_default,
msgpack_opts,
)

if Version(distributed.__version__) >= Version("2023.8.1"):
# Monkey-patch protocol.core.loads (and its users)
async def from_frames(
frames, deserialize=True, deserializers=None, allow_offload=True
):
"""
Unserialize a list of Distributed protocol frames.
"""
size = False

def _from_frames():
try:
# Patched code
return loads(
frames, deserialize=deserialize, deserializers=deserializers
)
# end patched code
except EOFError:
if size > 1000:
datastr = "[too large to display]"
else:
datastr = frames
# Aid diagnosing
logger.error("truncated data stream (%d bytes): %s", size, datastr)
raise

if allow_offload and deserialize and OFFLOAD_THRESHOLD:
size = sum(map(nbytes, frames))
if (
allow_offload
and deserialize
and OFFLOAD_THRESHOLD
and size > OFFLOAD_THRESHOLD
):
res = await offload(_from_frames)
else:
res = _from_frames()

return res

def loads(frames, deserialize=True, deserializers=None):
"""Transform bytestream back into Python value"""

allow_pickle = dask.config.get("distributed.scheduler.pickle")

try:

def _decode_default(obj):
offset = obj.get("__Serialized__", 0)
if offset > 0:
sub_header = msgpack.loads(
frames[offset],
object_hook=msgpack_decode_default,
use_list=False,
**msgpack_opts,
)
offset += 1
sub_frames = frames[offset : offset + sub_header["num-sub-frames"]]
if deserialize:
if "compression" in sub_header:
sub_frames = decompress(sub_header, sub_frames)
return merge_and_deserialize(
sub_header, sub_frames, deserializers=deserializers
)
else:
return Serialized(sub_header, sub_frames)

offset = obj.get("__Pickled__", 0)
if offset > 0:
sub_header = msgpack.loads(frames[offset])
offset += 1
sub_frames = frames[offset : offset + sub_header["num-sub-frames"]]
# Patched code
if "compression" in sub_header:
sub_frames = decompress(sub_header, sub_frames)
# end patched code
if allow_pickle:
return pickle.loads(
sub_header["pickled-obj"], buffers=sub_frames
)
else:
raise ValueError(
"Unpickle on the Scheduler isn't allowed, "
"set `distributed.scheduler.pickle=true`"
)

return msgpack_decode_default(obj)

return msgpack.loads(
frames[0], object_hook=_decode_default, use_list=False, **msgpack_opts
)

except Exception:
logger.critical("Failed to deserialize", exc_info=True)
raise

distributed.protocol.loads = loads
distributed.protocol.core.loads = loads
distributed.comm.utils.from_frames = from_frames
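Taken together, `dask_cuda/compat.py` patches `distributed>=2023.8.1`: the sections marked "# Patched code" make `from_frames` call the local `loads`, and add a `decompress` step for compressed sub-frames in the `__Pickled__` branch before `pickle.loads`. A minimal sketch of how the patch is activated, assuming dask-cuda and a qualifying `distributed` version are installed:

# Importing dask_cuda executes `from . import compat`, which runs the version
# check above and, when it passes, swaps in the patched functions.
from packaging.version import Version

import distributed
import distributed.comm.utils
import distributed.protocol.core

import dask_cuda

if Version(distributed.__version__) >= Version("2023.8.1"):
    # The module-level names in distributed now point at the patched versions.
    assert distributed.protocol.core.loads is dask_cuda.compat.loads
    assert distributed.comm.utils.from_frames is dask_cuda.compat.from_frames
else:
    # Older distributed releases are left untouched; compat defines no patch.
    assert not hasattr(dask_cuda.compat, "loads")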
26 changes: 14 additions & 12 deletions dask_cuda/explicit_comms/dataframe/shuffle.py
@@ -328,7 +328,7 @@ async def shuffle_task(
ignore_index: bool,
num_rounds: int,
batchsize: int,
- ) -> List[DataFrame]:
+ ) -> Dict[int, DataFrame]:
"""Explicit-comms shuffle task
This function is running on each worker participating in the shuffle.
@@ -360,8 +360,8 @@ async def shuffle_task(
Returns
-------
- partitions: list of DataFrames
- List of dataframe-partitions
+ partitions: dict
+ dict that maps each Partition ID to a dataframe-partition
"""

proxify = get_proxify(s["worker"])
@@ -387,14 +387,13 @@
)

# Finally, we concatenate the output dataframes into the final output partitions
- ret = []
+ ret = {}
while out_part_id_to_dataframe_list:
- ret.append(
- proxify(
- dd_concat(
- out_part_id_to_dataframe_list.popitem()[1],
- ignore_index=ignore_index,
- )
+ part_id, dataframe_list = out_part_id_to_dataframe_list.popitem()
+ ret[part_id] = proxify(
+ dd_concat(
+ dataframe_list,
+ ignore_index=ignore_index,
)
)
# For robustness, we yield this task to give Dask a chance to do bookkeeping
@@ -529,9 +528,12 @@ def shuffle(

dsk = {}
for rank in ranks:
- for i, part_id in enumerate(rank_to_out_part_ids[rank]):
+ for part_id in rank_to_out_part_ids[rank]:
dsk[(name, part_id)] = c.client.submit(
- getitem, shuffle_result[rank], i, workers=[c.worker_addresses[rank]]
+ getitem,
+ shuffle_result[rank],
+ part_id,
+ workers=[c.worker_addresses[rank]],
)

# Create a distributed Dataframe from all the pieces
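The net effect of the shuffle changes above: each worker now returns a dict keyed by output-partition ID instead of a positional list, so the driver-side `getitem` tasks can fetch a partition by `part_id` directly rather than by enumeration index. A small standalone illustration (the values here are placeholder strings, not dask-cuda objects):

from operator import getitem

# Before: results were positional, so the caller had to track the index that
# each output partition happened to get on its worker.
result_as_list = ["df for partition 3", "df for partition 7"]
assert getitem(result_as_list, 0) == "df for partition 3"

# After: results are keyed by the global output-partition ID itself, so the
# graph can ask for a partition by ID regardless of production order.
result_as_dict = {3: "df for partition 3", 7: "df for partition 7"}
assert getitem(result_as_dict, 7) == "df for partition 7"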
43 changes: 29 additions & 14 deletions dask_cuda/tests/test_explicit_comms.py
@@ -93,7 +93,7 @@ def check_partitions(df, npartitions):
return True


- def _test_dataframe_shuffle(backend, protocol, n_workers):
+ def _test_dataframe_shuffle(backend, protocol, n_workers, _partitions):
if backend == "cudf":
cudf = pytest.importorskip("cudf")

@@ -112,6 +112,9 @@ def _test_dataframe_shuffle(backend, protocol, n_workers):
if backend == "cudf":
df = cudf.DataFrame.from_pandas(df)

+ if _partitions:
+ df["_partitions"] = 0

for input_nparts in range(1, 5):
for output_nparts in range(1, 5):
ddf = dd.from_pandas(df.copy(), npartitions=input_nparts).persist(
@@ -123,33 +126,45 @@
with dask.config.set(explicit_comms_batchsize=batchsize):
ddf = explicit_comms_shuffle(
ddf,
["key"],
["_partitions"] if _partitions else ["key"],
npartitions=output_nparts,
batchsize=batchsize,
).persist()

assert ddf.npartitions == output_nparts

- # Check that each partition hashes to the same value
- result = ddf.map_partitions(
- check_partitions, output_nparts
- ).compute()
- assert all(result.to_list())
-
- # Check the values (ignoring the row order)
- expected = df.sort_values("key")
- got = ddf.compute().sort_values("key")
- assert_eq(got, expected)
+ if _partitions:
+ # If "_partitions" is the hash key, we expect all but
+ # the first partition to be empty
+ assert_eq(ddf.partitions[0].compute(), df)
+ assert all(
+ len(ddf.partitions[i].compute()) == 0
+ for i in range(1, ddf.npartitions)
+ )
+ else:
+ # Check that each partition hashes to the same value
+ result = ddf.map_partitions(
+ check_partitions, output_nparts
+ ).compute()
+ assert all(result.to_list())
+
+ # Check the values (ignoring the row order)
+ expected = df.sort_values("key")
+ got = ddf.compute().sort_values("key")
+ assert_eq(got, expected)


@pytest.mark.parametrize("nworkers", [1, 2, 3])
@pytest.mark.parametrize("backend", ["pandas", "cudf"])
@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
- def test_dataframe_shuffle(backend, protocol, nworkers):
+ @pytest.mark.parametrize("_partitions", [True, False])
+ def test_dataframe_shuffle(backend, protocol, nworkers, _partitions):
if backend == "cudf":
pytest.importorskip("cudf")

- p = mp.Process(target=_test_dataframe_shuffle, args=(backend, protocol, nworkers))
+ p = mp.Process(
+ target=_test_dataframe_shuffle, args=(backend, protocol, nworkers, _partitions)
+ )
p.start()
p.join()
assert not p.exitcode
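The new `_partitions` case shuffles on a column that is 0 for every row, so every row is routed to the first output partition and the remaining partitions must come back empty. A tiny pandas-only illustration of why a constant key collapses everything into one group (no cluster required, and not dask-cuda code):

import pandas as pd

# A constant "_partitions" column: every row carries the same routing value.
df = pd.DataFrame({"key": [4, 1, 3, 2], "_partitions": 0})

# Grouping on that column yields exactly one group holding the whole frame,
# mirroring the test's expectation that only partition 0 is non-empty.
groups = {part_id: part for part_id, part in df.groupby("_partitions")}
assert list(groups) == [0]
assert len(groups[0]) == len(df)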
18 changes: 18 additions & 0 deletions dask_cuda/tests/test_from_array.py
@@ -0,0 +1,18 @@
import pytest

import dask.array as da
from distributed import Client

from dask_cuda import LocalCUDACluster

pytest.importorskip("ucp")
cupy = pytest.importorskip("cupy")


@pytest.mark.parametrize("protocol", ["ucx", "tcp"])
def test_ucx_from_array(protocol):
N = 10_000
with LocalCUDACluster(protocol=protocol) as cluster:
with Client(cluster):
val = da.from_array(cupy.arange(N), chunks=(N // 10,)).sum().compute()
assert val == (N * (N - 1)) // 2
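The expected value follows from the closed-form sum 0 + 1 + ... + (N - 1) = N * (N - 1) / 2, which the integer expression in the assertion computes:

# Sanity check of the closed form used in test_ucx_from_array's assertion.
N = 10_000
assert sum(range(N)) == (N * (N - 1)) // 2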
2 changes: 1 addition & 1 deletion dask_cuda/tests/test_proxy.py
@@ -400,7 +400,7 @@ def _pxy_deserialize(self):

@pytest.mark.parametrize("send_serializers", [None, ("dask", "pickle"), ("cuda",)])
@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
- @gen_test(timeout=60)
+ @gen_test(timeout=120)
async def test_communicating_proxy_objects(protocol, send_serializers):
"""Testing serialization of cuDF dataframe when communicating"""
cudf = pytest.importorskip("cudf")
(Diffs for the remaining changed files were not loaded.)