diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py
index 2fd135d5475d..851a0b595bc6 100644
--- a/colossalai/pipeline/p2p.py
+++ b/colossalai/pipeline/p2p.py
@@ -9,6 +9,7 @@
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 from torch.distributed import distributed_c10d as c10d
+from version_parser.version import Version
 
 from .stage_manager import PipelineStageManager
 
@@ -61,17 +62,6 @@ def _broadcast_object_list(object_list: List[Any],
         c10d._warn_not_in_group("broadcast_object_list")
         return
 
-    my_rank = dist.get_rank()
-    # Serialize object_list elements to tensors on src rank.
-    if my_rank == src:
-        if torch.__version__ >= "1.13.0":
-            tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=device) for obj in object_list])
-        else:
-            tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list])
-        object_sizes_tensor = torch.cat(size_list)
-    else:
-        object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long)
-
     is_nccl_backend = c10d._check_for_nccl_backend(group)
     current_device = None
 
@@ -83,6 +73,18 @@ def _broadcast_object_list(object_list: List[Any],
         current_device = torch.device("cpu")
         if is_nccl_backend:
             current_device = torch.device("cuda", torch.cuda.current_device())
+
+    my_rank = dist.get_rank()
+    # Serialize object_list elements to tensors on src rank.
+    if my_rank == src:
+        if Version(torch.__version__) >= Version("1.13.0"):
+            tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=current_device) for obj in object_list])
+        else:
+            tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list])
+        object_sizes_tensor = torch.cat(size_list)
+    else:
+        object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long)
+
     if is_nccl_backend:
         object_sizes_tensor = object_sizes_tensor.to(current_device)
 
diff --git a/tests/test_shardformer/test_model/test_pure_pipeline.py b/tests/test_shardformer/test_model/test_pure_pipeline.py
index 80767f71c3fb..2f51eb9b02f7 100644
--- a/tests/test_shardformer/test_model/test_pure_pipeline.py
+++ b/tests/test_shardformer/test_model/test_pure_pipeline.py
@@ -1,3 +1,4 @@
+import copy
 import random
 from contextlib import nullcontext
 from typing import Any, Callable, Iterator, List, Optional, Tuple
@@ -6,7 +7,6 @@
 import pytest
 import torch
 import torch.distributed as dist
-from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
@@ -94,10 +94,10 @@ def execute_pipeline(
     return outputs
 
 
-class data_iter():
+class data_loader():
 
     def __getitem__(self, x):
-        return torch.randint(0, 100, (4, 128)).cuda()
+        return torch.ones((4, 128), dtype=torch.int).cuda() * 10
 
 
 def loss(x, y):
@@ -127,20 +127,30 @@ def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_la
     stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
     sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        if name != 'transformers_llama':
+            continue
         num_microbatches = 2
         org_model = model_fn().cuda()
+        data_iter = iter(data_loader())
+
+        model_copy = copy.deepcopy(org_model)
+        batch = next(data_iter)
+        with torch.no_grad():
+            y = model_copy(batch)
+            org_loss = loss(batch, y)
         optimizer = torch.optim.AdamW(org_model.parameters(), lr=1e-3)
-        #dataloader=prepare_dataloader(dataset=dataset['train'],batch_size=4)
         schedule = OneForwardOneBackwardSchedule(num_microbatches, stage_manager)
         shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                                    enable_tensor_parallelism=enable_tensor_parallelism,
                                    pipeline_stage_manager=stage_manager)
        pipelined_model = PipelinedModel(org_model, shard_config, stage_manager)
         pp_optimizer = PipelineOptimizer(optimizer, pipelined_model)
-        data_it = iter(data_iter())
-        results = execute_pipeline(data_it, pipelined_model, loss, pp_optimizer, schedule=schedule)
+        results = execute_pipeline(data_iter, pipelined_model, loss, pp_optimizer, schedule=schedule)
+
         if stage_manager.is_last_stage():
-            assert results['loss'] is not None
+            assert results['loss'] == org_loss
+        else:
+            assert results['loss'] is None
             assert results['outputs'] is None
         torch.cuda.empty_cache()
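Note on the p2p.py change above: the old code compared torch.__version__ against "1.13.0" as plain strings, which is a lexicographic comparison and misclassifies releases such as 1.9.x; parsing both sides into Version objects compares the release components numerically. Relocating the serialization block after the device resolution also lets _object_to_tensor target current_device rather than the possibly-unset device argument. Below is a minimal, illustrative sketch of the string-comparison pitfall; it uses packaging.version purely for demonstration rather than the version_parser import used in the diff.

    # Illustrative only: why comparing version strings lexicographically is wrong.
    # Assumes the "packaging" library is available; the diff itself imports Version
    # from version_parser instead.
    from packaging.version import Version

    old_torch = "1.9.0"

    # String comparison goes character by character: "9" > "1", so the check
    # wrongly reports that 1.9.0 is at least 1.13.0.
    assert (old_torch >= "1.13.0") is True

    # Parsed versions compare release numbers numerically, giving the intended result.
    assert (Version(old_torch) >= Version("1.13.0")) is False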