From af6aa9ed0668d8f98e32bcaa807fc752913c7e0c Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 27 Sep 2024 14:48:55 +0800 Subject: [PATCH] [plugin] hybrid support zero bubble pipeline (#6060) * hybrid support zbv * fix fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * Update zero_bubble_pp.py * fix * fix-ci * fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * fix * [zerobubble]Support ZeroBubble Pipeline (#6034) * [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble; * [feat] add dw test; * [fix] fix weight not close; * [update] update text; * [feat] add test run_fwd_bwd automatic scheduling; * [feat] split communication and calculation; fix pop empty send_bwd_buffer error; * [feat] add test for p & p grad; * [feat] add comments for ZBV func; * [fix] rm useless assign and comments; * [fix] fix ci test; add pytest; * [feat] add run_fwd_bwd_with_microbatch (replace input) & test; add p&p.grad assert close test & all pass; * [feat] add apply v_schedule graph; p & p.grad assert err exist; * [fix] update * [feat] fix ci; add assert; * [feat] fix poc format * [feat] fix func name & ci; add comments; * [fix] fix poc test; add comments in poc; * [feat] add optim backward_b_by_grad * [feat] fix optimizer bwd b & w; support return accum loss & output * [feat] add fwd_bwd_step, run_fwd_only; * [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict; * [fix] fix communication_map; * [feat] update test; rm comments; * [fix] rm zbv in hybridplugin * [fix] fix optim bwd; * [fix] fix optim bwd; * [fix] rm output.data after send fwd; * [fix] fix bwd step if condition; remove useless comments and format info; * [fix] fix detach output & release output; * [fix] rm requir_grad for output; * [fix] fix requir grad position and detach position and input&output local buffer append position; * [feat] add memory assertation; * [fix] fix mem check; * [fix] mem assertation' * [fix] fix mem assertation * [fix] fix mem; use a new model shape; only assert mem less and equal than theo; * [fix] fix model zoo import; * [fix] fix redundant detach & clone; add buffer assertation in the end; * [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap; * [fix] update optim state dict assert (include param group & state); fix mem assert after add optim; * [fix] add testcase with microbatch 4; * hybrid support zbv * fix fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update zero_bubble_pp.py * fix * fix-ci * fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see 
https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * fix * fix * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * fix --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: duanjunwen <935724073@qq.com> --- .github/workflows/build_on_pr.yml | 2 +- .github/workflows/build_on_schedule.yml | 2 +- .../naive_amp/mixed_precision_mixin/base.py | 2 +- .../naive_amp/mixed_precision_optimizer.py | 13 ++-- .../booster/mixed_precision/fp16_torch.py | 4 +- .../booster/plugin/hybrid_parallel_plugin.py | 63 ++++++++++++------- .../plugin/moe_hybrid_parallel_plugin.py | 6 +- colossalai/interface/optimizer.py | 4 +- colossalai/pipeline/stage_manager.py | 6 +- colossalai/shardformer/policies/llama.py | 12 +++- colossalai/zero/gemini/gemini_ddp.py | 2 +- colossalai/zero/gemini/gemini_optimizer.py | 6 +- colossalai/zero/low_level/low_level_optim.py | 13 ++-- tests/test_shardformer/test_model/_utils.py | 14 +++-- .../test_model/test_shard_llama.py | 44 ++++++++++++- 15 files changed, 140 insertions(+), 53 deletions(-) diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index 58cd8826809a..79d758c87976 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -140,7 +140,7 @@ jobs: - name: Install Colossal-AI run: | - BUILD_EXT=1 pip install -v -e . + BUILD_EXT=1 pip install -v . pip install --no-cache-dir -r requirements/requirements-test.txt - name: Store Colossal-AI Cache diff --git a/.github/workflows/build_on_schedule.yml b/.github/workflows/build_on_schedule.yml index fc688a71bd92..e7b5063279eb 100644 --- a/.github/workflows/build_on_schedule.yml +++ b/.github/workflows/build_on_schedule.yml @@ -55,7 +55,7 @@ jobs: if: steps.check-avai.outputs.avai == 'true' run: | [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/ - BUILD_EXT=1 pip install -v -e . + BUILD_EXT=1 pip install -v . cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/ pip install --no-cache-dir -r requirements/requirements-test.txt diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/base.py b/colossalai/amp/naive_amp/mixed_precision_mixin/base.py index fc7e0b74179a..b2ba47f6762d 100644 --- a/colossalai/amp/naive_amp/mixed_precision_mixin/base.py +++ b/colossalai/amp/naive_amp/mixed_precision_mixin/base.py @@ -43,7 +43,7 @@ def zero_grad(self): dtype: torch.dtype @abstractmethod - def pre_backward(self, loss: Tensor) -> Tensor: + def pre_backward(self, loss: Tensor, *args, **kwargs) -> Tensor: """Called before backward. 
Args: diff --git a/colossalai/amp/naive_amp/mixed_precision_optimizer.py b/colossalai/amp/naive_amp/mixed_precision_optimizer.py index 9e07bdebf8fa..8fb56aee4fce 100644 --- a/colossalai/amp/naive_amp/mixed_precision_optimizer.py +++ b/colossalai/amp/naive_amp/mixed_precision_optimizer.py @@ -85,13 +85,18 @@ def __init__( master_params.append(master_p) group["params"] = master_params - def backward(self, loss: Tensor, *args, **kwargs): + def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs): loss = self.mixed_precision.pre_backward(loss) - loss.backward(*args, **kwargs) + loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs) - def backward_by_grad(self, tensor: Tensor, grad: Tensor): + def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False): grad = self.mixed_precision.pre_backward_by_grad(tensor, grad) - tensor.backward(grad) + torch.autograd.backward( + tensors=tensor, + grad_tensors=grad, + inputs=inputs, + retain_graph=retain_graph, + ) def zero_grad(self, *args, **kwargs): for p in self.working_to_master_map.keys(): diff --git a/colossalai/booster/mixed_precision/fp16_torch.py b/colossalai/booster/mixed_precision/fp16_torch.py index c757a878d97a..a85d9f808546 100644 --- a/colossalai/booster/mixed_precision/fp16_torch.py +++ b/colossalai/booster/mixed_precision/fp16_torch.py @@ -46,9 +46,9 @@ def __init__( growth_interval=growth_interval, ) - def backward(self, loss: Tensor, *args, **kwargs) -> None: + def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs) -> None: scaled_loss = self.scale_loss(loss) - scaled_loss.backward(*args, **kwargs) + scaled_loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs) def step(self, *args, **kwargs) -> Optional[float]: out = self.scaler.step(self.optim, *args, **kwargs) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 1b3b765c2ff0..5d114ab9c315 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -28,7 +28,7 @@ from colossalai.interface.optimizer import DistributedOptim from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed -from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule +from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.quantization import BnbQuantizationConfig, quantize_model from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer @@ -288,7 +288,7 @@ def __init__( self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 super().__init__(optim) - def backward(self, loss: Tensor, *args, **kwargs): + def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs): r""" Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. @@ -306,7 +306,7 @@ def backward(self, loss: Tensor, *args, **kwargs): """ # Call the superclass backward method to compute gradients. - super().backward(loss, *args, **kwargs) + super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs) if self.model.require_grad_sync: # If gradient synchronization is required, sync sequence parallelism gradients. 
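The `inputs=` and `retain_graph=` arguments threaded through `backward`/`backward_by_grad` in the hunks above are what allow a zero bubble schedule to split the backward pass into a B step (gradients w.r.t. the stage input, needed immediately so the previous stage can proceed) and a deferred W step (gradients w.r.t. the weights, used to fill pipeline bubbles). A minimal sketch of that mechanism with plain `torch.autograd.backward` — illustrative only, not the scheduler's actual code:

```python
import torch

# One pipeline "stage": a single linear layer standing in for a model chunk.
model = torch.nn.Linear(4, 4)
stage_input = torch.randn(2, 4, requires_grad=True)  # activation received from the previous stage
output = model(stage_input)
output_grad = torch.ones_like(output)                # gradient sent back by the next stage

# B step: only gradients w.r.t. stage_input; keep the graph alive for the later W step.
torch.autograd.backward(output, grad_tensors=output_grad,
                        inputs=[stage_input], retain_graph=True)
assert stage_input.grad is not None and model.weight.grad is None

# W step: weight gradients, computed later to fill the pipeline bubble.
torch.autograd.backward(output, grad_tensors=output_grad,
                        inputs=list(model.parameters()))
assert model.weight.grad is not None
```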
@@ -315,7 +315,7 @@ def backward(self, loss: Tensor, *args, **kwargs): # If gradient synchronization is is not required, return. return - def backward_by_grad(self, tensor: Tensor, grad: Tensor): + def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False): """ Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. @@ -332,7 +332,7 @@ def backward_by_grad(self, tensor: Tensor, grad: Tensor): """ # Call the superclass backward method to compute gradients. - super().backward_by_grad(tensor, grad) + super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph) if self.model.require_grad_sync: # If gradient synchronization is required, sync sequence parallelism gradients. @@ -512,7 +512,7 @@ def __init__( max_norm=max_norm, ) - def backward(self, loss: Tensor, *args, **kwargs): + def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs): r""" Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. @@ -529,7 +529,7 @@ def backward(self, loss: Tensor, *args, **kwargs): None """ # Call the superclass backward method to compute gradients. - super().backward(loss, *args, **kwargs) + super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs) if self.model.require_grad_sync: # If gradient synchronization is required, sync sequence parallelism gradients. @@ -538,7 +538,7 @@ def backward(self, loss: Tensor, *args, **kwargs): # If gradient synchronization is is not required, return. return - def backward_by_grad(self, tensor: Tensor, grad: Tensor): + def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False): """ Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. @@ -554,7 +554,7 @@ def backward_by_grad(self, tensor: Tensor, grad: Tensor): None """ # Call the superclass backward method to compute gradients. - super().backward_by_grad(tensor, grad) + super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph) if self.model.require_grad_sync: # If gradient synchronization is required, sync sequence parallelism gradients. @@ -768,7 +768,7 @@ def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]: else: return - def backward(self, loss, retain_graph=False): + def backward(self, loss, inputs=None, retain_graph=False): """ Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. @@ -784,7 +784,7 @@ def backward(self, loss, retain_graph=False): None """ # Call the superclass backward method to compute gradients. - super().backward(loss, retain_graph) + super().backward(loss, inputs=inputs, retain_graph=retain_graph) if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: # If gradient synchronization is required, sync sequence parallelism gradients. @@ -793,7 +793,7 @@ def backward(self, loss, retain_graph=False): # If gradient synchronization is is not required, return. return - def backward_by_grad(self, tensor, grad): + def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False): """ Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. 
@@ -809,7 +809,7 @@ def backward_by_grad(self, tensor, grad): None """ # Call the superclass backward_by_grad method to compute gradients. - super().backward_by_grad(tensor, grad) + super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph) if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: # If gradient synchronization is required, sync sequence parallelism gradients. @@ -1013,6 +1013,7 @@ def __init__( custom_policy: Policy = None, pp_style: str = "1f1b", num_model_chunks: int = 1, + scheduler_nodes: List = None, num_layers_per_stage: Optional[List[int]] = None, gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None, enable_metadata_cache: bool = True, @@ -1029,6 +1030,9 @@ def __init__( dist.get_world_size() % (tp_size * pp_size) == 0 ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" + assert ( + not pp_style == "zbv" or scheduler_nodes is not None + ), f"scheduler_nodes must not be None when using zero bubble pipeline." if enable_sequence_parallelism: self.sequence_parallelism_mode = ( sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all" @@ -1088,29 +1092,39 @@ def __init__( self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size) self.stage_manager = None - self.schedule = None + self.scheduler = None self.custom_policy = custom_policy assert zero_stage in (0, 1, 2) if self.pp_size > 1: - assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style" - assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b" + assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style" + assert ( + pp_style in ["interleaved", "zbv"] or num_model_chunks == 1 + ), "num_model_chunks must be 1 when using 1f1b" + assert ( + pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2 + ), "num_model_chunks must be 2 when using zero bubble pipeline" assert ( num_microbatches is not None or microbatch_size is not None ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" assert ( self.zero_stage <= 1 ), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism" + if pp_style == "zbv": + self.logger.warning( + """the enable_gradient_checkpointing function must set the use_reentrant to False, such as model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':False})""" + ) self.stage_manager = PipelineStageManager( self.pg_mesh, pipeline_axis=self.pp_axis, - enable_interleave=(pp_style == "interleaved"), + enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"), + use_zbv=(pp_style == "zbv"), num_model_chunks=num_model_chunks, num_layers_per_stage=num_layers_per_stage, ) if pp_style == "interleaved": assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" - self.schedule = InterleavedSchedule( + self.scheduler = InterleavedSchedule( stage_manager=self.stage_manager, num_model_chunks=num_model_chunks, num_microbatch=num_microbatches, @@ -1119,12 +1133,20 @@ def __init__( overlap_p2p=overlap_p2p, ) elif pp_style == "1f1b": - self.schedule = OneForwardOneBackwardSchedule( + self.scheduler = OneForwardOneBackwardSchedule( stage_manager=self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size, enable_metadata_cache=enable_metadata_cache, ) + elif 
pp_style == "zbv": + self.scheduler = ZeroBubbleVPipeScheduler( + stage_manager=self.stage_manager, + schedule=scheduler_nodes, + num_model_chunks=num_model_chunks, + num_microbatch=num_microbatches, + microbatch_size=microbatch_size, + ) else: raise NotImplementedError() if sequence_parallelism_mode == "ring_attn": @@ -1236,7 +1258,6 @@ def configure( # Replace with distributed implementation if exists optimizer = cast_to_distributed(optimizer) - if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0: self.logger.warning( "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.", @@ -1352,7 +1373,7 @@ def execute_pipeline( ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() with ctx, model._wait_all_gather(): - outputs = self.schedule.forward_backward_step( + outputs = self.scheduler.forward_backward_step( model, data_iter, criterion, optimizer, return_loss, return_outputs ) diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 56405ed47e00..fe12645374db 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -280,7 +280,7 @@ def __init__( self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size) self.stage_manager = None - self.schedule = None + self.scheduler = None self.custom_policy = custom_policy assert zero_stage in (0, 1, 2) if self.pp_size > 1: @@ -304,7 +304,7 @@ def __init__( if pp_style == "interleaved": assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" - self.schedule = InterleavedSchedule( + self.scheduler = InterleavedSchedule( stage_manager=self.stage_manager, num_model_chunks=num_model_chunks, num_microbatch=num_microbatches, @@ -313,7 +313,7 @@ def __init__( overlap_p2p=overlap_p2p, ) elif pp_style == "1f1b": - self.schedule = OneForwardOneBackwardSchedule( + self.scheduler = OneForwardOneBackwardSchedule( stage_manager=self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size, diff --git a/colossalai/interface/optimizer.py b/colossalai/interface/optimizer.py index a236434a55d6..c8cf3ec21360 100644 --- a/colossalai/interface/optimizer.py +++ b/colossalai/interface/optimizer.py @@ -49,11 +49,11 @@ def zero_grad(self, *args, **kwargs): """ self.optim.zero_grad(*args, **kwargs) - def backward(self, loss: Tensor, *args, **kwargs): + def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs): """ Performs a backward pass on the loss. 
""" - loss.backward(*args, **kwargs) + loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs) def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False): """ diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index 50cc965bb9c3..5cc32114daff 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -136,7 +136,11 @@ def is_last_stage(self, ignore_chunk: bool = False) -> bool: if not self.is_interleave or ignore_chunk: return self.stage == self.num_stages - 1 else: - return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1 + # use zero bubble pipeline + if self.use_zbv: + return self.stage == 0 and self.model_chunk_id == self.num_model_chunks - 1 + else: + return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1 @property def num_stages(self) -> int: diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 67e6e92d1d36..60da448d8767 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -261,7 +261,9 @@ def get_held_layers(self) -> List[Module]: held_layers.append(module.embed_tokens) for start_idx, end_idx in stage_indices: held_layers.extend(module.layers[start_idx:end_idx]) - if stage_manager.is_last_stage(ignore_chunk=True): + if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.norm) + elif stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(module.norm) else: @@ -351,7 +353,9 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(ignore_chunk=True): + if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(self.model.lm_head) + elif stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(self.model.lm_head) return held_layers @@ -404,7 +408,9 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(ignore_chunk=True): + if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(self.model.score) + elif stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(self.model.score) return held_layers diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 80b2c7961e29..d2754cbd965b 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -373,7 +373,7 @@ def backward(self, loss: torch.Tensor): loss.backward() self._post_backward() - def backward_by_grad(self, tensor, grad): + def backward_by_grad(self, tensor, grad, inputs: torch.Tensor = None, retain_graph: bool = False): raise RuntimeError("Gemini is not compatible with pipeline. 
backward_by_grad shoudn't be called in Gemini.") @staticmethod diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index fdf2a497626f..ccd4634b5fe2 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -298,12 +298,14 @@ def backward(self, loss: torch.Tensor): loss = self.mix_precision_mixin.pre_backward(loss) self.module.backward(loss) - def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor): + def backward_by_grad( + self, tensor: torch.Tensor, grad: torch.Tensor, inputs: torch.Tensor = None, retain_graph: bool = False + ): # This function is called except the last stage of pipeline parallel # It receives the scaled grad from the previous rank # No need to scale the grad again # Need to unscale when optimizing - grad = self.mix_precision_mixin.pre_backward_by_grad(grad) + grad = self.mix_precision_mixin.pre_backward_by_grad(grad, inputs=inputs, retain_graph=retain_graph) self.module.backward_by_grad(tensor, grad) def _maybe_move_fp32_params(self): diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 51d7d1eaaa33..9cc44c7538dd 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -408,7 +408,7 @@ def _add_to_bucket(self, param, group_id): # torch.optim.Optimizer methods ################################ - def backward(self, loss, retain_graph=False): + def backward(self, loss, inputs=None, retain_graph=False): assert not ( self._partition_grads and not self.require_grad_sync ), "ZeRO2(partition_grads) and no_sync are not compatible" @@ -416,7 +416,7 @@ def backward(self, loss, retain_graph=False): if self.mixed_precision_mixin is not None: loss = self.mixed_precision_mixin.pre_backward(loss) - loss.backward(retain_graph=retain_graph) + loss.backward(inputs=inputs, retain_graph=retain_graph) if not self.require_grad_sync: return @@ -427,14 +427,19 @@ def backward(self, loss, retain_graph=False): if self._overlap_communication: get_accelerator().synchronize() - def backward_by_grad(self, tensor, grad): + def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False): assert not ( self._partition_grads and not self.require_grad_sync ), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible" if self.mixed_precision_mixin is not None: grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad) - torch.autograd.backward(tensor, grad) + torch.autograd.backward( + tensor, + grad, + inputs=inputs, + retain_graph=retain_graph, + ) if not self.require_grad_sync: return diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 9ad84341ac9e..5c141e8f5cf1 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -157,7 +157,6 @@ def build_model_from_hybrid_plugin( sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3) criterion = loss_fn - plugin = pluggin_cls(**test_config) booster = Booster(plugin=plugin) @@ -311,8 +310,16 @@ def check_output_hidden_state( ): org_hidden_state = org_output.last_hidden_state - if stage_manager and stage_manager.is_last_stage(ignore_chunk=True): - sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"] + if stage_manager: + if stage_manager.use_zbv: + if stage_manager.is_first_stage(ignore_chunk=True): + sharded_hidden_state = 
sharded_output["outputs"]["last_hidden_state"] + else: + sharded_hidden_state = sharded_output.last_hidden_state + elif stage_manager.is_last_stage(ignore_chunk=True): + sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"] + else: + sharded_hidden_state = sharded_output.last_hidden_state else: sharded_hidden_state = sharded_output.last_hidden_state @@ -390,7 +397,6 @@ def get_grad_tensors_for_check( pass if verbose and dist.get_rank() == 0: print(f"'{suffix}' grad: {org_grad}, {shard_grad}") - grad_to_check[suffix] = { "org_grad": org_grad.float(), "shard_grad": shard_grad.float(), diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 3c66f609787a..d925687cd875 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -7,6 +7,7 @@ import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.schedule.v_schedule import PipelineGraph from colossalai.shardformer import PipelineGradientCheckpointConfig from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter @@ -33,7 +34,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) if enable_gradient_checkpointing: # org_model.gradient_checkpointing_enable() - sharded_model.unwrap().gradient_checkpointing_enable() + sharded_model.unwrap().gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster @@ -112,12 +113,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, sharded_optimizer.step() # check last hidden state & loss - if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True): + check_flag = False + if stage_manager is None: + check_flag = True + else: + if stage_manager.use_zbv: + if stage_manager.is_first_stage(ignore_chunk=True): + check_flag = True + elif stage_manager.is_last_stage(ignore_chunk=True): + check_flag = True + if check_flag: if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ == "LlamaModel": check_output_hidden_state( org_output, @@ -282,10 +291,39 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "initial_scale": 1, }, + { + "tp_size": 2, + "pp_size": 2, + "pp_style": "zbv", + "num_model_chunks": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "precision": "fp16", + "zero_stage": 0, + "initial_scale": 1, + "enable_gradient_checkpointing": True, + "parallel_output": False, + }, ], ) def run_llama_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") + if test_config.get("pp_style", None) == "zbv": + mem_f = 34 * 32 + 5 * 4 * 16 + mem_w = -32 * 32 + mem_b = -mem_w - mem_f + scheduler_nodes = PipelineGraph( + n_stage=test_config["pp_size"], + n_micro=test_config["num_microbatches"], + f_cost=1000, + b_cost=1000, + w_cost=1000, + c_cost=1, + f_mem=mem_f, + b_mem=mem_b, + w_mem=mem_w, + ).get_v_schedule() + test_config["scheduler_nodes"] = scheduler_nodes for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): if test_config.get("sequence_parallelism_mode", None) 
== "ring_attn" and "causal" not in name: continue