Skip to content

Commit

Permalink
add option to disable pipeline partitioning (microsoft#4322)
Browse files Browse the repository at this point in the history
added pipeline configuration to deepspeed config under pipeline section
pipe_partitioned = "auto"
grad_partitioned = "auto"
can be used to enable or disable activations and grads partitioning in
TP/PP mode

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
  • Loading branch information
3 people authored and amaurya committed Feb 17, 2024
1 parent 59ef4b9 commit 24640b3
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
2 changes: 2 additions & 0 deletions deepspeed/runtime/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,8 @@ def get_pipeline_config(param_dict):
"partition": "best",
"seed_layers": False,
"activation_checkpoint_interval": 0,
"pipe_partitioned": True,
"grad_partitioned": True,
}
config = default_pipeline
for key, val in param_dict.get("pipeline", {}).items():
Expand Down
8 changes: 6 additions & 2 deletions deepspeed/runtime/pipe/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,12 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):

# Partition input/output buffers
# XXX temporarily disable while I revert some partition hacks.
self.is_pipe_partitioned = self.is_model_parallel
self.is_grad_partitioned = self.is_model_parallel
assert isinstance(self._config.pipeline['pipe_partitioned'], bool)
assert isinstance(self._config.pipeline['grad_partitioned'], bool)
self.is_pipe_partitioned = self.is_model_parallel and self._config.pipeline['pipe_partitioned']
self.is_grad_partitioned = self.is_model_parallel and self._config.pipeline['grad_partitioned']
logger.info(f'is_pipe_partitioned= {self.is_pipe_partitioned}',
f'is_grad_partitioned= {self.is_grad_partitioned}')

model_parameters = filter(lambda p: p.requires_grad, self.module.parameters())
num_params = sum([p.numel() for p in model_parameters])
Expand Down

0 comments on commit 24640b3

Please sign in to comment.