diff --git a/torchrec/distributed/composable/tests/test_ddp.py b/torchrec/distributed/composable/tests/test_ddp.py index 4c77fd5ed..b6291afc9 100644 --- a/torchrec/distributed/composable/tests/test_ddp.py +++ b/torchrec/distributed/composable/tests/test_ddp.py @@ -9,13 +9,10 @@ #!/usr/bin/env python3 -import os import tempfile import unittest -import uuid import torch -from torch import distributed as dist from torch.distributed._composable import replicate from torch.distributed._shard.api import ShardedTensor from torch.distributed.checkpoint import ( @@ -24,167 +21,142 @@ load_state_dict, save_state_dict, ) -from torch.distributed.launcher.api import elastic_launch, LaunchConfig from torchrec.distributed.shard import shard as trec_shard, shard_modules from torchrec.distributed.sharding_plan import column_wise +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + MultiProcessTestBase, +) from torchrec.distributed.test_utils.test_model import ModelInput, TestSparseNN from torchrec.modules.embedding_configs import EmbeddingBagConfig from torchrec.test_utils import skip_if_asan -class DDPTest(unittest.TestCase): +class DDPTest(MultiProcessTestBase): @classmethod - def _run_init_parameters(cls, path: str) -> None: - rank = int(os.environ["LOCAL_RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - if torch.cuda.is_available(): - device: torch.device = torch.device(f"cuda:{rank}") - backend = "nccl" - torch.cuda.set_device(device) - else: - device: torch.device = torch.device("cpu") - backend = "gloo" - dist.init_process_group( - backend=backend, - rank=rank, - world_size=world_size, - init_method=f"file://{os.path.join(path, 'dist_rdvz')}", - ) - num_float_features = 32 - - tables = [ - EmbeddingBagConfig( - num_embeddings=(i + 1) * 10, - embedding_dim=(i + 1) * 4 * world_size, - name="table_" + str(i), - feature_names=["feature_" + str(i)], - ) - for i in range(3) - ] - weighted_tables = [ - EmbeddingBagConfig( - num_embeddings=(i + 1) * 10, - embedding_dim=(i + 1) * 4 * world_size, - name="weighted_table_" + str(i), - feature_names=["weighted_feature_" + str(i)], - ) - for i in range(2) - ] - m = TestSparseNN( - tables=tables, - num_float_features=num_float_features, - weighted_tables=weighted_tables, - dense_device=device, - ) - # Put all tensors on meta device, then init_params should - # materialize them. - for name, param in m._parameters.items(): - if isinstance(param, torch.Tensor): - m._parameters[name] = torch.nn.Parameter( - torch.empty_like(param, device="meta"), - requires_grad=param.requires_grad, + def _run_init(cls, rank: int, world_size: int) -> None: + with MultiProcessContext(rank, world_size, "nccl") as ctx: + num_float_features = 32 + + tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 10, + embedding_dim=(i + 1) * 4 * world_size, + name="table_" + str(i), + feature_names=["feature_" + str(i)], ) - - shard_modules(m, device=device, init_params=True) - # init_params should move m to `device` - for p in m.parameters(): - assert p.device == device + for i in range(3) + ] + weighted_tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 10, + embedding_dim=(i + 1) * 4 * world_size, + name="weighted_table_" + str(i), + feature_names=["weighted_feature_" + str(i)], + ) + for i in range(2) + ] + m = TestSparseNN( + tables=tables, + num_float_features=num_float_features, + weighted_tables=weighted_tables, + dense_device=ctx.device, + ) + # Put all tensors on meta device, then init_params should + # materialize them. + for name, param in m._parameters.items(): + if isinstance(param, torch.Tensor): + m._parameters[name] = torch.nn.Parameter( + torch.empty_like(param, device="meta"), + requires_grad=param.requires_grad, + ) + + shard_modules(m, device=ctx.device, init_params=True) + # init_params should move m to `device` + for p in m.parameters(): + assert p.device == ctx.device @classmethod - def _run(cls, path: str) -> None: - rank = int(os.environ["LOCAL_RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - if torch.cuda.is_available(): - device: torch.device = torch.device(f"cuda:{rank}") - backend = "nccl" - torch.cuda.set_device(device) - else: - device: torch.device = torch.device("cpu") - backend = "gloo" - dist.init_process_group( - backend=backend, - rank=rank, - world_size=world_size, - init_method=f"file://{os.path.join(path, 'dist_rdvz')}", - ) - num_float_features = 32 - - tables = [ - EmbeddingBagConfig( - num_embeddings=(i + 1) * 10, - embedding_dim=(i + 1) * 4 * world_size, - name="table_" + str(i), - feature_names=["feature_" + str(i)], + def _run(cls, rank: int, world_size: int, path: str) -> None: + with MultiProcessContext(rank, world_size, "nccl") as ctx: + num_float_features = 32 + + tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 10, + embedding_dim=(i + 1) * 4 * world_size, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(3) + ] + weighted_tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 10, + embedding_dim=(i + 1) * 4 * world_size, + name="weighted_table_" + str(i), + feature_names=["weighted_feature_" + str(i)], + ) + for i in range(2) + ] + m = TestSparseNN( + tables=tables, + num_float_features=num_float_features, + weighted_tables=weighted_tables, + dense_device=ctx.device, ) - for i in range(3) - ] - weighted_tables = [ - EmbeddingBagConfig( - num_embeddings=(i + 1) * 10, - embedding_dim=(i + 1) * 4 * world_size, - name="weighted_table_" + str(i), - feature_names=["weighted_feature_" + str(i)], + m.sparse.ebc = trec_shard( + module=m.sparse.ebc, + device=ctx.device, + plan=column_wise(ranks=list(range(world_size))), ) - for i in range(2) - ] - m = TestSparseNN( - tables=tables, - num_float_features=num_float_features, - weighted_tables=weighted_tables, - dense_device=device, - ) - m.sparse.ebc = trec_shard( - module=m.sparse.ebc, - device=device, - plan=column_wise(ranks=list(range(world_size))), - ) - m.sparse.weighted_ebc = trec_shard( - module=m.sparse.weighted_ebc, - device=device, - plan=column_wise(ranks=list(range(world_size))), - ) - m.over = replicate(m.over) - m.dense = replicate(m.dense) - - ######## run one iteration ######## - _, local_batch = ModelInput.generate( - batch_size=8, - world_size=world_size, - num_float_features=num_float_features, - tables=tables, - weighted_tables=weighted_tables, - ) - batch = local_batch[0].to(device) - m(batch)[1].sum().backward() - - state_dict = m.state_dict() - writer = FileSystemWriter(path=path) - reader = FileSystemReader(path=path) - save_state_dict(state_dict, writer) - - p_sum = torch.zeros(1, device=device) - for p in m.parameters(): - with torch.no_grad(): - if isinstance(p, ShardedTensor): - if not p.local_shards(): - continue - p = p.local_tensor() - p_sum += p.sum() - p.zero_() - assert p.sum() == 0 - load_state_dict(state_dict, reader) - m.load_state_dict(state_dict) - - p_sum_loaded = torch.zeros(1, device=device) - for p in m.parameters(): - with torch.no_grad(): - if isinstance(p, ShardedTensor): - if not p.local_shards(): - continue - p = p.local_tensor() - p_sum_loaded += p.sum() - # TODO: debug why failing on OSS - # assert p_sum.allclose(p_sum_loaded) + m.sparse.weighted_ebc = trec_shard( + module=m.sparse.weighted_ebc, + device=ctx.device, + plan=column_wise(ranks=list(range(world_size))), + ) + m.over = replicate(m.over) + m.dense = replicate(m.dense) + + ######## run one iteration ######## + _, local_batch = ModelInput.generate( + batch_size=8, + world_size=world_size, + num_float_features=num_float_features, + tables=tables, + weighted_tables=weighted_tables, + ) + batch = local_batch[0].to(ctx.device) + m(batch)[1].sum().backward() + + state_dict = m.state_dict() + writer = FileSystemWriter(path=path) + reader = FileSystemReader(path=path) + save_state_dict(state_dict, writer) + + p_sum = torch.zeros(1, device=ctx.device) + for p in m.parameters(): + with torch.no_grad(): + if isinstance(p, ShardedTensor): + if not p.local_shards(): + continue + p = p.local_tensor() + p_sum += p.sum() + p.zero_() + assert p.sum() == 0 + load_state_dict(state_dict, reader) + m.load_state_dict(state_dict) + + p_sum_loaded = torch.zeros(1, device=ctx.device) + for p in m.parameters(): + with torch.no_grad(): + if isinstance(p, ShardedTensor): + if not p.local_shards(): + continue + p = p.local_tensor() + p_sum_loaded += p.sum() + # TODO: debug why failing on OSS + # assert p_sum.allclose(p_sum_loaded) @skip_if_asan # pyre-fixme[56]: Pyre was not able to infer the type of argument @@ -195,18 +167,10 @@ def _run(cls, path: str) -> None: ) def test_checkpoint(self) -> None: with tempfile.TemporaryDirectory() as path: - lc = LaunchConfig( - min_nodes=1, - max_nodes=1, - nproc_per_node=2, - run_id=str(uuid.uuid4()), - rdzv_backend="c10d", - rdzv_endpoint="localhost:0", - start_method="spawn", - monitor_interval=1, - max_restarts=0, + self._run_multi_process_test( + callable=self._run, + path=path, ) - elastic_launch(config=lc, entrypoint=self._run)(path) @skip_if_asan # pyre-fixme[56]: Pyre was not able to infer the type of argument @@ -216,15 +180,6 @@ def test_checkpoint(self) -> None: "Not enough GPUs, this test requires at least two GPUs", ) def test_init_params(self) -> None: - with tempfile.TemporaryDirectory() as path: - lc = LaunchConfig( - min_nodes=1, - max_nodes=1, - nproc_per_node=2, - run_id=str(uuid.uuid4()), - rdzv_backend="c10d", - start_method="spawn", - monitor_interval=1, - max_restarts=0, - ) - elastic_launch(config=lc, entrypoint=self._run_init_parameters)(path) + self._run_multi_process_test( + callable=self._run_init, + )