Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Fix cuda storage transfer deadlock on multiple GPUs #788

Open
wants to merge 32 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
4048a5b
REF: refactor transfer to avoid deadlocks
UranusSeven Jun 1, 2023
be78817
Debugging
UranusSeven Jun 2, 2023
640864a
Debugging
UranusSeven Jun 2, 2023
6ab37e1
Debugging
UranusSeven Jun 2, 2023
ad42bc8
Fix
UranusSeven Jun 2, 2023
f434eab
Debugging
UranusSeven Jun 2, 2023
4d80dbb
Fix
UranusSeven Jun 13, 2023
30f58a7
Fix
UranusSeven Jun 14, 2023
6c13816
checkpoint
UranusSeven Jun 15, 2023
8fea60f
fix
ChengjieLi28 Aug 7, 2023
5dafe70
REF: refactor transfer to avoid deadlocks
UranusSeven Jun 1, 2023
89ad935
Debugging
UranusSeven Jun 2, 2023
dd63a1c
Debugging
UranusSeven Jun 2, 2023
df23a15
Debugging
UranusSeven Jun 2, 2023
072848f
Fix
UranusSeven Jun 2, 2023
373e7bc
Debugging
UranusSeven Jun 2, 2023
aa264be
Fix
UranusSeven Jun 13, 2023
010871d
Fix
UranusSeven Jun 14, 2023
4a629ab
checkpoint
UranusSeven Jun 15, 2023
faa2096
fix
ChengjieLi28 Aug 7, 2023
c485cb1
Merge branch 'ref/transfer' of github.com:UranusSeven/xorbits into st…
luweizheng Jun 24, 2024
4026b4c
test
luweizheng Jul 5, 2024
27a8333
remove debug log
luweizheng Jul 5, 2024
77cf425
Merge branch 'main' into feat/cudf
mergify[bot] Jul 6, 2024
c26c46d
delete logs
luweizheng Jul 6, 2024
7ceaae8
Merge branch 'feat/cudf' of github.com:luweizheng/xorbits into feat/cudf
luweizheng Jul 7, 2024
9ba9de3
debug log
luweizheng Jul 7, 2024
65b37e9
Merge branch 'main' into feat/cudf
mergify[bot] Jul 12, 2024
1433869
merge main
luweizheng Aug 23, 2024
af2fe83
fix test
luweizheng Aug 23, 2024
4a059f4
Merge branch 'main' into feat/cudf
luweizheng Sep 4, 2024
44fa88a
Merge branch 'main' into feat/cudf
mergify[bot] Sep 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -728,8 +728,7 @@ def _disallow_combine_and_agg(ctx, op):
pd.testing.assert_frame_equal(result.sort_index(), raw.groupby("c1").agg("sum"))


@pytest.mark.skip_ray_dag # _fetch_infos() is not supported by ray backend.
def test_distributed_groupby_agg(setup_cluster):
def test_distributed_groupby_agg(setup):
rs = np.random.RandomState(0)
raw = pd.DataFrame(rs.rand(50000, 10))
df = md.DataFrame(raw, chunk_size=raw.shape[0] // 2)
Expand Down
166 changes: 149 additions & 17 deletions python/xorbits/_mars/services/storage/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
self._quota_refs = quota_refs
self._band_name = band_name
self._supervisor_address = None
self._lock = asyncio.Lock()

@classmethod
def gen_uid(cls, band_name: str):
Expand Down Expand Up @@ -292,6 +293,7 @@ async def delete_object(

@mo.extensible
async def delete(self, session_id: str, data_key: str, error: str = "raise"):
logger.debug("Delete %s, %s on %s", session_id, data_key, self.address)
if error not in ("raise", "ignore"): # pragma: no cover
raise ValueError("error must be raise or ignore")

Expand Down Expand Up @@ -382,6 +384,7 @@ async def batch_delete(self, args_list, kwargs_list):
await self._quota_refs[level].release_quota(size)

@mo.extensible
@mo.no_lock
async def open_reader(self, session_id: str, data_key: str) -> StorageFileObject:
data_info = await self._data_manager_ref.get_data_info(
session_id, data_key, self._band_name
Expand All @@ -390,6 +393,7 @@ async def open_reader(self, session_id: str, data_key: str) -> StorageFileObject
return reader

@open_reader.batch
@mo.no_lock
async def batch_open_readers(self, args_list, kwargs_list):
get_data_infos = []
for args, kwargs in zip(args_list, kwargs_list):
Expand Down Expand Up @@ -522,7 +526,21 @@ async def _fetch_remote(
await self._data_manager_ref.put_data_info.batch(*put_data_info_delays)
await asyncio.gather(*fetch_tasks)

async def _fetch_via_transfer(
async def get_receive_manager_ref(self, band_name: str):
    """Resolve the local ReceiverManagerActor serving *band_name*.

    The lookup is always against this handler's own address, since the
    receiver side of a transfer lives on the fetching worker.
    """
    # Imported lazily to avoid a circular import with .transfer.
    from .transfer import ReceiverManagerActor

    receiver_uid = ReceiverManagerActor.gen_uid(band_name)
    ref = await mo.actor_ref(address=self.address, uid=receiver_uid)
    return ref

@staticmethod
async def get_send_manager_ref(address: str, band: str):
    """Resolve the SenderManagerActor for *band* on the remote *address*.

    Static because the sender lives on the remote worker; no local
    handler state is needed for the lookup.
    """
    # Imported lazily to avoid a circular import with .transfer.
    from .transfer import SenderManagerActor

    sender_uid = SenderManagerActor.gen_uid(band)
    return await mo.actor_ref(address=address, uid=sender_uid)

async def fetch_via_transfer(
self,
session_id: str,
data_keys: List[Union[str, tuple]],
Expand All @@ -531,21 +549,136 @@ async def _fetch_via_transfer(
fetch_band_name: str,
error: str,
):
from .transfer import SenderManagerActor
from .transfer import ReceiverManagerActor, SenderManagerActor

logger.debug("Begin to fetch %s from band %s", data_keys, remote_band)
sender_ref: mo.ActorRefType[SenderManagerActor] = await mo.actor_ref(
address=remote_band[0], uid=SenderManagerActor.gen_uid(remote_band[1])

remote_data_manager_ref: mo.ActorRefType[DataManagerActor] = await mo.actor_ref(
address=remote_band[0], uid=DataManagerActor.default_uid()
)
await sender_ref.send_batch_data(
session_id,

logger.debug("Getting actual keys for %s", data_keys)
tasks = []
for key in data_keys:
tasks.append(remote_data_manager_ref.get_store_key.delay(session_id, key))
data_keys = await remote_data_manager_ref.get_store_key.batch(*tasks)
data_keys = list(set(data_keys))

logger.debug("Getting sub infos for %s", data_keys)
sub_infos = await remote_data_manager_ref.get_sub_infos.batch(
*[
remote_data_manager_ref.get_sub_infos.delay(session_id, key)
for key in data_keys
]
)

get_info_tasks = []
pin_tasks = []
for data_key in data_keys:
get_info_tasks.append(
remote_data_manager_ref.get_data_info.delay(
session_id, data_key, remote_band[1], error
)
)
pin_tasks.append(
remote_data_manager_ref.pin.delay(
session_id, data_key, remote_band[1], error
)
)
logger.debug("Getting data infos for %s", data_keys)
infos = await remote_data_manager_ref.get_data_info.batch(*get_info_tasks)
logger.debug("Pining %s", data_keys)
await remote_data_manager_ref.pin.batch(*pin_tasks)

filtered = [
(data_info, data_key)
for data_info, data_key in zip(infos, data_keys)
if data_info is not None
]
if filtered:
infos, data_keys = zip(*filtered)
else: # pragma: no cover
# no data to be transferred
return []
data_sizes = [info.store_size for info in infos]

if level is None:
level = infos[0].level

receiver_ref: mo.ActorRefType[
ReceiverManagerActor
] = await self.get_receive_manager_ref(fetch_band_name)

await self.request_quota_with_spill(level, sum(data_sizes))

open_writer_tasks = []
for data_key, data_size, sub_info in zip(data_keys, data_sizes, sub_infos):
open_writer_tasks.append(
self.open_writer.delay(
session_id,
data_key,
data_size,
level,
request_quota=False,
band_name=fetch_band_name,
)
)
writers = await self.open_writer.batch(*open_writer_tasks)
is_transferring_list = await receiver_ref.add_writers(
session_id, data_keys, data_sizes, sub_infos, writers, level
)

to_send_keys = []
to_wait_keys = []
wait_sizes = []
for data_key, is_transferring, _size in zip(
data_keys, is_transferring_list, data_sizes
):
if is_transferring:
to_wait_keys.append(data_key)
wait_sizes.append(_size)
else:
to_send_keys.append(data_key)

# Overapplied the quota for these wait keys, and now need to update the quota
if to_wait_keys:
self._quota_refs[level].update_quota(-sum(wait_sizes))

logger.debug(
"Start transferring %s from %s to %s",
data_keys,
self._data_manager_ref.address,
level,
fetch_band_name,
error=error,
remote_band,
(self.address, fetch_band_name),
)
logger.debug("Finish fetching %s from band %s", data_keys, remote_band)
sender_ref: mo.ActorRefType[
SenderManagerActor
] = await self.get_send_manager_ref(remote_band[0], remote_band[1])

try:
await sender_ref.send_batch_data(
session_id,
data_keys,
to_send_keys,
to_wait_keys,
(self.address, fetch_band_name),
)
await receiver_ref.handle_transmission_done(session_id, to_send_keys)
except asyncio.CancelledError:
keys_to_delete = await receiver_ref.handle_transmission_cancellation(
session_id, to_send_keys
)
for key in keys_to_delete:
await self.delete(session_id, key, error="ignore")
raise

unpin_tasks = []
for data_key in data_keys:
unpin_tasks.append(
remote_data_manager_ref.unpin.delay(
session_id, [data_key], remote_band[1], error="ignore"
)
)
await remote_data_manager_ref.unpin.batch(*unpin_tasks)

async def fetch_batch(
self,
Expand All @@ -559,10 +692,8 @@ async def fetch_batch(
if error not in ("raise", "ignore"): # pragma: no cover
raise ValueError("error must be raise or ignore")

meta_api = await self._get_meta_api(session_id)
remote_keys = defaultdict(set)
missing_keys = []
get_metas = []
get_info_delays = []
for data_key in data_keys:
get_info_delays.append(
Expand All @@ -586,6 +717,9 @@ async def fetch_batch(
else:
# Not exists in local, fetch from remote worker
missing_keys.append(data_key)
await self._data_manager_ref.pin.batch(*pin_delays)

meta_api = await self._get_meta_api(session_id)
if address is None or band_name is None:
# some mapper keys are absent, specify error='ignore'
# remember that meta only records those main keys
Expand All @@ -599,16 +733,14 @@ async def fetch_batch(
)
for data_key in missing_keys
]
await self._data_manager_ref.pin.batch(*pin_delays)

if get_metas:
metas = await meta_api.get_chunk_meta.batch(*get_metas)
else: # pragma: no cover
metas = [{"bands": [(address, band_name)]}] * len(missing_keys)
assert len(metas) == len(missing_keys)
for data_key, bands in zip(missing_keys, metas):
if bands is not None:
remote_keys[bands["bands"][0]].add(data_key)

transfer_tasks = []
fetch_keys = []
for band, keys in remote_keys.items():
Expand All @@ -620,7 +752,7 @@ async def fetch_batch(
else:
# fetch via transfer
transfer_tasks.append(
self._fetch_via_transfer(
self.fetch_via_transfer(
session_id, list(keys), level, band, band_name or band[1], error
)
)
Expand Down
Loading
Loading