[Quality] Avoid torch.distributed imports at root #1134

Merged: 1 commit, Dec 9, 2024
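The change moves every `torch.distributed` import out of module scope and into the function bodies that need it, so `import tensordict` no longer loads `torch.distributed` as a side effect. A minimal sketch of the deferred-import pattern being applied (the function here is illustrative, not part of the tensordict API):

    import torch

    def reduce_sum(tensor: torch.Tensor, dst: int = 0) -> None:
        # Deferred import: torch.distributed is only loaded the first
        # time a distributed code path actually runs.
        from torch import distributed as dist

        dist.reduce(tensor, dst=dst, op=dist.ReduceOp.SUM)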
9 changes: 4 additions & 5 deletions tensordict/_lazy.py
@@ -32,7 +32,6 @@
 import orjson as json
 import torch
-import torch.distributed as dist
 
 from tensordict.memmap import MemoryMappedTensor
@@ -2388,7 +2387,7 @@ def _send(
         dst: int,
         _tag: int = -1,
         pseudo_rand: bool = False,
-        group: "dist.ProcessGroup" | None = None,
+        group: "torch.distributed.ProcessGroup" | None = None,
     ) -> int:
         for td in self.tensordicts:
             _tag = td._send(dst, _tag=_tag, pseudo_rand=pseudo_rand, group=group)
@@ -2400,7 +2399,7 @@ def _isend(
         _tag: int = -1,
         _futures: list[torch.Future] | None = None,
         pseudo_rand: bool = False,
-        group: "dist.ProcessGroup" | None = None,
+        group: "torch.distributed.ProcessGroup" | None = None,
     ) -> int:
         if _futures is None:
             is_root = True
@@ -2421,7 +2420,7 @@ def _recv(
         src: int,
         _tag: int = -1,
         pseudo_rand: bool = False,
-        group: "dist.ProcessGroup" | None = None,
+        group: "torch.distributed.ProcessGroup" | None = None,
     ) -> int:
         for td in self.tensordicts:
             _tag = td._recv(src, _tag=_tag, pseudo_rand=pseudo_rand, group=group)
@@ -2434,7 +2433,7 @@ def _irecv(
         _tag: int = -1,
         _future_list: list[torch.Future] = None,
         pseudo_rand: bool = False,
-        group: "dist.ProcessGroup" | None = None,
+        group: "torch.distributed.ProcessGroup" | None = None,
     ) -> tuple[int, list[torch.Future]] | list[torch.Future] | None:
         root = False
         if _future_list is None:
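Note that these annotations were already strings, and the file must be using `from __future__ import annotations` (a bare `"..." | None` union would not evaluate otherwise), so writing out the full `torch.distributed.ProcessGroup` path costs nothing at runtime. An alternative that keeps the short name resolvable for type checkers without any runtime import is a `typing.TYPE_CHECKING` guard; a sketch of that approach, which is not what this PR does:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen by type checkers only; never executed at runtime.
        from torch import distributed as dist

    def broadcast(group: dist.ProcessGroup | None = None) -> None:
        ...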
34 changes: 24 additions & 10 deletions tensordict/base.py
@@ -98,7 +98,7 @@
     unravel_key,
     unravel_key_list,
 )
-from torch import distributed as dist, multiprocessing as mp, nn, Tensor
+from torch import multiprocessing as mp, nn, Tensor
 from torch.nn.parameter import Parameter, UninitializedTensorMixin
 from torch.utils._pytree import tree_map
@@ -7260,7 +7260,7 @@ def del_(self, key: NestedKey) -> T:
 
     # Distributed functionality
     def gather_and_stack(
-        self, dst: int, group: "dist.ProcessGroup" | None = None
+        self, dst: int, group: "torch.distributed.ProcessGroup" | None = None
     ) -> T | None:
         """Gathers tensordicts from various workers and stacks them onto self in the destination worker.
 
@@ -7319,6 +7319,8 @@ def gather_and_stack(
         ...     main_worker.join()
         ...     secondary_worker.join()
         """
+        from torch import distributed as dist
+
         output = (
             [None for _ in range(dist.get_world_size(group=group))]
             if dst == dist.get_rank(group=group)
@@ -7336,7 +7338,7 @@ def send(
         self,
         dst: int,
         *,
-        group: "dist.ProcessGroup" | None = None,
+        group: "torch.distributed.ProcessGroup" | None = None,
         init_tag: int = 0,
         pseudo_rand: bool = False,
     ) -> None:  # noqa: D417
@@ -7426,8 +7428,10 @@ def _send(
         dst: int,
         _tag: int = -1,
         pseudo_rand: bool = False,
-        group: "dist.ProcessGroup" | None = None,
+        group: "torch.distributed.ProcessGroup" | None = None,
     ) -> int:
+        from torch import distributed as dist
+
         for key in self.sorted_keys:
             value = self._get_str(key, NO_DEFAULT)
             if isinstance(value, Tensor):
@@ -7449,7 +7453,7 @@ def recv(
         self,
         src: int,
         *,
-        group: "dist.ProcessGroup" | None = None,
+        group: "torch.distributed.ProcessGroup" | None = None,
         init_tag: int = 0,
         pseudo_rand: bool = False,
     ) -> int:  # noqa: D417
@@ -7481,9 +7485,11 @@ def _recv(
         src: int,
         _tag: int = -1,
         pseudo_rand: bool = False,
-        group: "dist.ProcessGroup" | None = None,
+        group: "torch.distributed.ProcessGroup" | None = None,
         non_blocking: bool = False,
     ) -> int:
+        from torch import distributed as dist
+
         for key in self.sorted_keys:
             value = self._get_str(key, NO_DEFAULT)
             if isinstance(value, Tensor):
@@ -7508,7 +7514,7 @@ def isend(
         self,
         dst: int,
         *,
-        group: "dist.ProcessGroup" | None = None,
+        group: "torch.distributed.ProcessGroup" | None = None,
         init_tag: int = 0,
         pseudo_rand: bool = False,
     ) -> int:  # noqa: D417
@@ -7603,8 +7609,10 @@ def _isend(
         _tag: int = -1,
         _futures: list[torch.Future] | None = None,
         pseudo_rand: bool = False,
-        group: "dist.ProcessGroup" | None = None,
+        group: "torch.distributed.ProcessGroup" | None = None,
     ) -> int:
+        from torch import distributed as dist
+
         root = False
         if _futures is None:
             root = True
@@ -7639,7 +7647,7 @@ def irecv(
         self,
         src: int,
         *,
-        group: "dist.ProcessGroup" | None = None,
+        group: "torch.distributed.ProcessGroup" | None = None,
         return_premature: bool = False,
         init_tag: int = 0,
         pseudo_rand: bool = False,
@@ -7687,8 +7695,10 @@ def _irecv(
         _tag: int = -1,
         _future_list: list[torch.Future] = None,
         pseudo_rand: bool = False,
-        group: "dist.ProcessGroup" | None = None,
+        group: "torch.distributed.ProcessGroup" | None = None,
     ) -> tuple[int, list[torch.Future]] | list[torch.Future] | None:
+        from torch import distributed as dist
+
         root = False
         if _future_list is None:
             _future_list = []
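For context on these signatures: the non-blocking variants hand back `torch.Future` objects that the caller waits on. A usage sketch for a send/recv pair, assuming the default process group is initialized on both ranks and that each rank constructs a tensordict with the same structure:

    import torch
    from torch import distributed as dist
    from tensordict import TensorDict

    td = TensorDict({"obs": torch.zeros(3)}, batch_size=[])
    if dist.get_rank() == 0:
        td["obs"] += 1.0
        td.isend(dst=1)
    else:
        # With return_premature=True, irecv returns futures to wait on
        # before the received values are guaranteed to be in place.
        futures = td.irecv(src=0, return_premature=True)
        for fut in futures:
            fut.wait()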
@@ -7736,6 +7746,8 @@ def reduce(
             Only the process with ``rank`` dst is going to receive the final result.
 
         """
+        from torch import distributed as dist
+
         if op is None:
             op = dist.ReduceOp.SUM
         return self._reduce(dst, op, async_op, return_premature, group=group)
@@ -7749,6 +7761,8 @@ def _reduce(
         _future_list=None,
         group=None,
     ):
+        from torch import distributed as dist
+
         if op is None:
             op = dist.ReduceOp.SUM
         root = False
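A function-local import is cheap after the first call: Python caches the module in `sys.modules`, so re-executing the statement reduces to a dict lookup and a name bind. A quick, illustrative way to check the per-call overhead (not part of the PR):

    import time

    def get_dist():
        from torch import distributed as dist  # cached after the first call
        return dist

    get_dist()  # the first call pays the real import cost
    t0 = time.perf_counter()
    for _ in range(100_000):
        get_dist()
    print(f"{(time.perf_counter() - t0) / 100_000 * 1e9:.0f} ns per call")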
4 changes: 2 additions & 2 deletions tensordict/tensorclass.pyi
@@ -50,7 +50,7 @@ from tensordict.utils import (
     unravel_key as unravel_key,
     unravel_key_list as unravel_key_list,
 )
-from torch import distributed as dist, multiprocessing as mp, nn, Tensor
+from torch import multiprocessing as mp, nn, Tensor
 
 class _NoDefault(enum.IntEnum):
     ZERO = 0
@@ -663,7 +663,7 @@ class TensorClass:
     ) -> T: ...
     def del_(self, key: NestedKey) -> T: ...
     def gather_and_stack(
-        self, dst: int, group: dist.ProcessGroup | None = None
+        self, dst: int, group: "dist.ProcessGroup" | None = None
     ) -> T | None: ...
     def send(
         self,
9 changes: 9 additions & 0 deletions test/smoke_test.py
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import argparse
+import sys
 
 import pytest
 
@@ -11,6 +12,14 @@ def test_imports():
     from tensordict import TensorDict  # noqa: F401
     from tensordict.nn import TensorDictModule  # noqa: F401
 
+    # # Check that distributed is not imported
+    # v = set(sys.modules.values())
+    # try:
+    #     from torch import distributed
+    # except ImportError:
+    #     return
+    # assert distributed not in v
+
 
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
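The disabled check above snapshots the loaded module objects, imports `torch.distributed`, and would assert that the module was not already present. A name-based sketch of the same idea (a hypothetical test, not part of the PR; it could still trip if torch itself loads its distributed submodule during `import torch`, which may be why the check ships commented out):

    import sys

    def test_distributed_not_imported():
        import tensordict  # noqa: F401

        # If tensordict had imported torch.distributed at module level,
        # its name would already be registered in sys.modules.
        assert "torch.distributed" not in sys.modules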