Skip to content

Commit

Permalink
[plugin] hybrid support zero bubble pipeline (#6060)
Browse files Browse the repository at this point in the history
* hybrid support zbv

* fix

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* Update zero_bubble_pp.py

* fix

* fix-ci

* fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* [zerobubble]Support ZeroBubble Pipeline (#6034)

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [feat] add dw test;

* [fix] fix weight not close;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] split communication and calculation; fix pop empty send_bwd_buffer error;

* [feat] add test for p & p grad;

* [feat] add comments for ZBV func;

* [fix] rm useless assign and comments;

* [fix] fix ci test; add pytest;

* [feat] add run_fwd_bwd_with_microbatch  (replace input) & test; add p&p.grad assert close test & all pass;

* [feat] add apply v_schedule graph; p & p.grad assert err exist;

* [fix] update

* [feat] fix ci; add assert;

* [feat] fix poc format

* [feat] fix func name & ci; add comments;

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [feat] add fwd_bwd_step, run_fwd_only;

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [fix] fix communication_map;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix detach output & release output;

* [fix] rm requir_grad for output;

* [fix] fix requir grad position and detach position and input&output local buffer append position;

* [feat] add memory assertation;

* [fix] fix mem check;

* [fix] mem assertation'

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [fix] fix redundant detach & clone; add buffer assertation in the end;

* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap;

* [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;

* [fix] add testcase with microbatch 4;

* hybrid support zbv

* fix

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update zero_bubble_pp.py

* fix

* fix-ci

* fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: duanjunwen <[email protected]>
  • Loading branch information
3 people authored Sep 27, 2024
1 parent b804fdc commit af6aa9e
Show file tree
Hide file tree
Showing 15 changed files with 140 additions and 53 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build_on_pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ jobs:
- name: Install Colossal-AI
run: |
BUILD_EXT=1 pip install -v -e .
BUILD_EXT=1 pip install -v .
pip install --no-cache-dir -r requirements/requirements-test.txt
- name: Store Colossal-AI Cache
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/build_on_schedule.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ jobs:
if: steps.check-avai.outputs.avai == 'true'
run: |
[ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/
BUILD_EXT=1 pip install -v -e .
BUILD_EXT=1 pip install -v .
cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/
pip install --no-cache-dir -r requirements/requirements-test.txt
Expand Down
2 changes: 1 addition & 1 deletion colossalai/amp/naive_amp/mixed_precision_mixin/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def zero_grad(self):
dtype: torch.dtype

@abstractmethod
def pre_backward(self, loss: Tensor) -> Tensor:
def pre_backward(self, loss: Tensor, *args, **kwargs) -> Tensor:
"""Called before backward.
Args:
Expand Down
13 changes: 9 additions & 4 deletions colossalai/amp/naive_amp/mixed_precision_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,18 @@ def __init__(
master_params.append(master_p)
group["params"] = master_params

def backward(self, loss: Tensor, *args, **kwargs):
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
loss = self.mixed_precision.pre_backward(loss)
loss.backward(*args, **kwargs)
loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)

def backward_by_grad(self, tensor: Tensor, grad: Tensor):
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
tensor.backward(grad)
torch.autograd.backward(
tensors=tensor,
grad_tensors=grad,
inputs=inputs,
retain_graph=retain_graph,
)

def zero_grad(self, *args, **kwargs):
for p in self.working_to_master_map.keys():
Expand Down
4 changes: 2 additions & 2 deletions colossalai/booster/mixed_precision/fp16_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ def __init__(
growth_interval=growth_interval,
)

def backward(self, loss: Tensor, *args, **kwargs) -> None:
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs) -> None:
scaled_loss = self.scale_loss(loss)
scaled_loss.backward(*args, **kwargs)
scaled_loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)

def step(self, *args, **kwargs) -> Optional[float]:
out = self.scaler.step(self.optim, *args, **kwargs)
Expand Down
63 changes: 42 additions & 21 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from colossalai.interface.optimizer import DistributedOptim
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
Expand Down Expand Up @@ -288,7 +288,7 @@ def __init__(
self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
super().__init__(optim)

def backward(self, loss: Tensor, *args, **kwargs):
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
r"""
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
Expand All @@ -306,7 +306,7 @@ def backward(self, loss: Tensor, *args, **kwargs):
"""

# Call the superclass backward method to compute gradients.
super().backward(loss, *args, **kwargs)
super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)

if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients.
Expand All @@ -315,7 +315,7 @@ def backward(self, loss: Tensor, *args, **kwargs):
# If gradient synchronization is is not required, return.
return

def backward_by_grad(self, tensor: Tensor, grad: Tensor):
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
"""
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
Expand All @@ -332,7 +332,7 @@ def backward_by_grad(self, tensor: Tensor, grad: Tensor):
"""

# Call the superclass backward method to compute gradients.
super().backward_by_grad(tensor, grad)
super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)

if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients.
Expand Down Expand Up @@ -512,7 +512,7 @@ def __init__(
max_norm=max_norm,
)

def backward(self, loss: Tensor, *args, **kwargs):
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
r"""
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
Expand All @@ -529,7 +529,7 @@ def backward(self, loss: Tensor, *args, **kwargs):
None
"""
# Call the superclass backward method to compute gradients.
super().backward(loss, *args, **kwargs)
super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)

if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients.
Expand All @@ -538,7 +538,7 @@ def backward(self, loss: Tensor, *args, **kwargs):
# If gradient synchronization is is not required, return.
return

def backward_by_grad(self, tensor: Tensor, grad: Tensor):
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
"""
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
Expand All @@ -554,7 +554,7 @@ def backward_by_grad(self, tensor: Tensor, grad: Tensor):
None
"""
# Call the superclass backward method to compute gradients.
super().backward_by_grad(tensor, grad)
super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)

if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients.
Expand Down Expand Up @@ -768,7 +768,7 @@ def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]:
else:
return

def backward(self, loss, retain_graph=False):
def backward(self, loss, inputs=None, retain_graph=False):
"""
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
Expand All @@ -784,7 +784,7 @@ def backward(self, loss, retain_graph=False):
None
"""
# Call the superclass backward method to compute gradients.
super().backward(loss, retain_graph)
super().backward(loss, inputs=inputs, retain_graph=retain_graph)

if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
# If gradient synchronization is required, sync sequence parallelism gradients.
Expand All @@ -793,7 +793,7 @@ def backward(self, loss, retain_graph=False):
# If gradient synchronization is is not required, return.
return

def backward_by_grad(self, tensor, grad):
def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False):
"""
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
Expand All @@ -809,7 +809,7 @@ def backward_by_grad(self, tensor, grad):
None
"""
# Call the superclass backward_by_grad method to compute gradients.
super().backward_by_grad(tensor, grad)
super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)

if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
# If gradient synchronization is required, sync sequence parallelism gradients.
Expand Down Expand Up @@ -1013,6 +1013,7 @@ def __init__(
custom_policy: Policy = None,
pp_style: str = "1f1b",
num_model_chunks: int = 1,
scheduler_nodes: List = None,
num_layers_per_stage: Optional[List[int]] = None,
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
enable_metadata_cache: bool = True,
Expand All @@ -1029,6 +1030,9 @@ def __init__(
dist.get_world_size() % (tp_size * pp_size) == 0
), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"

assert (
not pp_style == "zbv" or scheduler_nodes is not None
), f"scheduler_nodes must not be None when using zero bubble pipeline."
if enable_sequence_parallelism:
self.sequence_parallelism_mode = (
sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
Expand Down Expand Up @@ -1088,29 +1092,39 @@ def __init__(
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)

self.stage_manager = None
self.schedule = None
self.scheduler = None
self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
assert (
pp_style in ["interleaved", "zbv"] or num_model_chunks == 1
), "num_model_chunks must be 1 when using 1f1b"
assert (
pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2
), "num_model_chunks must be 2 when using zero bubble pipeline"
assert (
num_microbatches is not None or microbatch_size is not None
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
assert (
self.zero_stage <= 1
), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism"
if pp_style == "zbv":
self.logger.warning(
"""the enable_gradient_checkpointing function must set the use_reentrant to False, such as model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':False})"""
)
self.stage_manager = PipelineStageManager(
self.pg_mesh,
pipeline_axis=self.pp_axis,
enable_interleave=(pp_style == "interleaved"),
enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
use_zbv=(pp_style == "zbv"),
num_model_chunks=num_model_chunks,
num_layers_per_stage=num_layers_per_stage,
)

if pp_style == "interleaved":
assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
self.schedule = InterleavedSchedule(
self.scheduler = InterleavedSchedule(
stage_manager=self.stage_manager,
num_model_chunks=num_model_chunks,
num_microbatch=num_microbatches,
Expand All @@ -1119,12 +1133,20 @@ def __init__(
overlap_p2p=overlap_p2p,
)
elif pp_style == "1f1b":
self.schedule = OneForwardOneBackwardSchedule(
self.scheduler = OneForwardOneBackwardSchedule(
stage_manager=self.stage_manager,
num_microbatches=num_microbatches,
microbatch_size=microbatch_size,
enable_metadata_cache=enable_metadata_cache,
)
elif pp_style == "zbv":
self.scheduler = ZeroBubbleVPipeScheduler(
stage_manager=self.stage_manager,
schedule=scheduler_nodes,
num_model_chunks=num_model_chunks,
num_microbatch=num_microbatches,
microbatch_size=microbatch_size,
)
else:
raise NotImplementedError()
if sequence_parallelism_mode == "ring_attn":
Expand Down Expand Up @@ -1236,7 +1258,6 @@ def configure(

# Replace with distributed implementation if exists
optimizer = cast_to_distributed(optimizer)

if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
self.logger.warning(
"Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
Expand Down Expand Up @@ -1352,7 +1373,7 @@ def execute_pipeline(
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()

with ctx, model._wait_all_gather():
outputs = self.schedule.forward_backward_step(
outputs = self.scheduler.forward_backward_step(
model, data_iter, criterion, optimizer, return_loss, return_outputs
)

Expand Down
6 changes: 3 additions & 3 deletions colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def __init__(
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size)

self.stage_manager = None
self.schedule = None
self.scheduler = None
self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
Expand All @@ -304,7 +304,7 @@ def __init__(

if pp_style == "interleaved":
assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
self.schedule = InterleavedSchedule(
self.scheduler = InterleavedSchedule(
stage_manager=self.stage_manager,
num_model_chunks=num_model_chunks,
num_microbatch=num_microbatches,
Expand All @@ -313,7 +313,7 @@ def __init__(
overlap_p2p=overlap_p2p,
)
elif pp_style == "1f1b":
self.schedule = OneForwardOneBackwardSchedule(
self.scheduler = OneForwardOneBackwardSchedule(
stage_manager=self.stage_manager,
num_microbatches=num_microbatches,
microbatch_size=microbatch_size,
Expand Down
4 changes: 2 additions & 2 deletions colossalai/interface/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ def zero_grad(self, *args, **kwargs):
"""
self.optim.zero_grad(*args, **kwargs)

def backward(self, loss: Tensor, *args, **kwargs):
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
"""
Performs a backward pass on the loss.
"""
loss.backward(*args, **kwargs)
loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)

def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
"""
Expand Down
6 changes: 5 additions & 1 deletion colossalai/pipeline/stage_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,11 @@ def is_last_stage(self, ignore_chunk: bool = False) -> bool:
if not self.is_interleave or ignore_chunk:
return self.stage == self.num_stages - 1
else:
return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1
# use zero bubble pipeline
if self.use_zbv:
return self.stage == 0 and self.model_chunk_id == self.num_model_chunks - 1
else:
return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1

@property
def num_stages(self) -> int:
Expand Down
12 changes: 9 additions & 3 deletions colossalai/shardformer/policies/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,9 @@ def get_held_layers(self) -> List[Module]:
held_layers.append(module.embed_tokens)
for start_idx, end_idx in stage_indices:
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage(ignore_chunk=True):
if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(module.norm)
elif stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(module.norm)

else:
Expand Down Expand Up @@ -351,7 +353,9 @@ def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage(ignore_chunk=True):
if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
elif stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
return held_layers

Expand Down Expand Up @@ -404,7 +408,9 @@ def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage(ignore_chunk=True):
if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(self.model.score)
elif stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.score)
return held_layers

Expand Down
2 changes: 1 addition & 1 deletion colossalai/zero/gemini/gemini_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def backward(self, loss: torch.Tensor):
loss.backward()
self._post_backward()

def backward_by_grad(self, tensor, grad):
def backward_by_grad(self, tensor, grad, inputs: torch.Tensor = None, retain_graph: bool = False):
raise RuntimeError("Gemini is not compatible with pipeline. backward_by_grad shoudn't be called in Gemini.")

@staticmethod
Expand Down
6 changes: 4 additions & 2 deletions colossalai/zero/gemini/gemini_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,12 +298,14 @@ def backward(self, loss: torch.Tensor):
loss = self.mix_precision_mixin.pre_backward(loss)
self.module.backward(loss)

def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
def backward_by_grad(
self, tensor: torch.Tensor, grad: torch.Tensor, inputs: torch.Tensor = None, retain_graph: bool = False
):
# This function is called except the last stage of pipeline parallel
# It receives the scaled grad from the previous rank
# No need to scale the grad again
# Need to unscale when optimizing
grad = self.mix_precision_mixin.pre_backward_by_grad(grad)
grad = self.mix_precision_mixin.pre_backward_by_grad(grad, inputs=inputs, retain_graph=retain_graph)
self.module.backward_by_grad(tensor, grad)

def _maybe_move_fp32_params(self):
Expand Down
Loading

0 comments on commit af6aa9e

Please sign in to comment.