From 24640b3c5d7f6b5cc71a0091405f13123361bec9 Mon Sep 17 00:00:00 2001
From: Nadav Elyahu <88962733+nelyahu@users.noreply.github.com>
Date: Wed, 8 Nov 2023 20:13:45 +0200
Subject: [PATCH] add option to disable pipeline partitioning (#4322)

added pipeline configuration to the deepspeed config under the "pipeline"
section. The boolean flags pipe_partitioned and grad_partitioned (default:
true) can be used to enable or disable partitioning of activations and
grads in TP/PP mode.

Co-authored-by: Olatunji Ruwase
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
---
 deepspeed/runtime/config.py      | 2 ++
 deepspeed/runtime/pipe/engine.py | 8 ++++++--
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py
index a1aa60760053..629f658e0790 100755
--- a/deepspeed/runtime/config.py
+++ b/deepspeed/runtime/config.py
@@ -449,6 +449,8 @@ def get_pipeline_config(param_dict):
         "partition": "best",
         "seed_layers": False,
         "activation_checkpoint_interval": 0,
+        "pipe_partitioned": True,
+        "grad_partitioned": True,
     }
     config = default_pipeline
     for key, val in param_dict.get("pipeline", {}).items():
diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py
index a4384f6ea711..fa9ff1fffe4b 100644
--- a/deepspeed/runtime/pipe/engine.py
+++ b/deepspeed/runtime/pipe/engine.py
@@ -133,8 +133,12 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):
 
         # Partition input/output buffers
         # XXX temporarily disable while I revert some partition hacks.
-        self.is_pipe_partitioned = self.is_model_parallel
-        self.is_grad_partitioned = self.is_model_parallel
+        assert isinstance(self._config.pipeline['pipe_partitioned'], bool)
+        assert isinstance(self._config.pipeline['grad_partitioned'], bool)
+        self.is_pipe_partitioned = self.is_model_parallel and self._config.pipeline['pipe_partitioned']
+        self.is_grad_partitioned = self.is_model_parallel and self._config.pipeline['grad_partitioned']
+        logger.info(f'is_pipe_partitioned= {self.is_pipe_partitioned} '
+                    f'is_grad_partitioned= {self.is_grad_partitioned}')
 
         model_parameters = filter(lambda p: p.requires_grad, self.module.parameters())
         num_params = sum([p.numel() for p in model_parameters])
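
--
Usage sketch. Only the "pipeline" keys below come from this patch; the
other config fields, the model object, and the commented-out initialize
call are illustrative assumptions for a TP/PP training script, not part
of the change:

    import deepspeed

    # A config dict that turns the new flags off, keeping full (unpartitioned)
    # activations and gradients between pipeline stages. Values must be
    # booleans: the engine asserts isinstance(..., bool) on both flags.
    ds_config = {
        "train_batch_size": 8,  # illustrative; any valid DeepSpeed config works
        "pipeline": {
            "pipe_partitioned": False,  # disable activation partitioning
            "grad_partitioned": False,  # disable gradient partitioning
        },
    }

    # my_pipeline_module is a placeholder for a deepspeed.pipe.PipelineModule:
    # engine, _, _, _ = deepspeed.initialize(model=my_pipeline_module,
    #                                        config=ds_config)

Note that both flags only take effect when the model is actually
model-parallel; is_pipe_partitioned/is_grad_partitioned are the AND of
is_model_parallel and the corresponding config flag.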