[feat] update optimizer bwd
duanjunwen committed Sep 29, 2024
1 parent d634795 commit 5c8bbf6
Showing 5 changed files with 36 additions and 16 deletions.
4 changes: 2 additions & 2 deletions colossalai/interface/optimizer.py
@@ -49,11 +49,11 @@ def zero_grad(self, *args, **kwargs):
"""
self.optim.zero_grad(*args, **kwargs)

def backward(self, loss: Tensor, *args, **kwargs):
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
"""
Performs a backward pass on the loss.
"""
loss.backward(*args, **kwargs)
loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)

def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
"""
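With this change the wrapper simply forwards inputs and retain_graph to torch.Tensor.backward. As a rough illustration (plain PyTorch, not part of the commit), these arguments let a caller accumulate gradients only for chosen leaves while keeping the graph alive for a later backward call:

import torch

model = torch.nn.Linear(8, 1)
x = torch.randn(4, 8, requires_grad=True)
loss = model(x).sum()

# Gradients are accumulated only for the listed leaves; the parameters stay untouched,
# and retain_graph=True keeps the graph usable for a separate backward call later.
loss.backward(inputs=[x], retain_graph=True)
print(x.grad is not None, model.weight.grad is None)  # True True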
2 changes: 1 addition & 1 deletion colossalai/zero/gemini/gemini_ddp.py
@@ -373,7 +373,7 @@ def backward(self, loss: torch.Tensor):
loss.backward()
self._post_backward()

def backward_by_grad(self, tensor, grad):
def backward_by_grad(self, tensor, grad, inputs: torch.Tensor = None, retain_graph: bool = False):
raise RuntimeError("Gemini is not compatible with pipeline. backward_by_grad shouldn't be called in Gemini.")

@staticmethod
6 changes: 4 additions & 2 deletions colossalai/zero/gemini/gemini_optimizer.py
@@ -298,12 +298,14 @@ def backward(self, loss: torch.Tensor):
loss = self.mix_precision_mixin.pre_backward(loss)
self.module.backward(loss)

def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
def backward_by_grad(
self, tensor: torch.Tensor, grad: torch.Tensor, inputs: torch.Tensor = None, retain_graph: bool = False
):
# This function is called by all pipeline stages except the last one
# It receives the scaled grad from the previous rank
# No need to scale the grad again
# Need to unscale when optimizing
grad = self.mix_precision_mixin.pre_backward_by_grad(grad)
grad = self.mix_precision_mixin.pre_backward_by_grad(grad, inputs=inputs, retain_graph=retain_graph)
self.module.backward_by_grad(tensor, grad)

def _maybe_move_fp32_params(self):
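For context, here is a self-contained sketch of the backward_by_grad(tensor, grad) pattern the comments above describe (assumption: two local nn.Linear layers stand in for real pipeline stages, with no actual communication): the later stage runs an ordinary backward, and the earlier stage resumes the backward pass from the gradient it receives.

import torch

stage0 = torch.nn.Linear(4, 4)   # earlier pipeline stage
stage1 = torch.nn.Linear(4, 1)   # later pipeline stage, owns the loss

x = torch.randn(2, 4)
act = stage0(x)                               # stage 0 forward output
act_recv = act.detach().requires_grad_(True)  # what stage 1 receives
loss = stage1(act_recv).sum()

loss.backward()                               # stage 1: ordinary backward
grad_to_send = act_recv.grad                  # returned to stage 0 via p2p in a real pipeline

# Stage 0 continues from the received gradient -- the core of backward_by_grad.
torch.autograd.backward(act, grad_to_send)
print(stage0.weight.grad is not None)         # True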
13 changes: 9 additions & 4 deletions colossalai/zero/low_level/low_level_optim.py
@@ -408,15 +408,15 @@ def _add_to_bucket(self, param, group_id):
# torch.optim.Optimizer methods
################################

def backward(self, loss, retain_graph=False):
def backward(self, loss, inputs=None, retain_graph=False):
assert not (
self._partition_grads and not self.require_grad_sync
), "ZeRO2(partition_grads) and no_sync are not compatible"

if self.mixed_precision_mixin is not None:
loss = self.mixed_precision_mixin.pre_backward(loss)

loss.backward(retain_graph=retain_graph)
loss.backward(inputs=inputs, retain_graph=retain_graph)

if not self.require_grad_sync:
return
@@ -427,14 +427,19 @@ def backward(self, loss, retain_graph=False):
if self._overlap_communication:
get_accelerator().synchronize()

def backward_by_grad(self, tensor, grad):
def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False):
assert not (
self._partition_grads and not self.require_grad_sync
), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"

if self.mixed_precision_mixin is not None:
grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
torch.autograd.backward(tensor, grad)
torch.autograd.backward(
tensor,
grad,
inputs=inputs,
retain_graph=retain_graph,
)

if not self.require_grad_sync:
return
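A hypothetical sketch (plain PyTorch, not taken from the commit) of why backward_by_grad now exposes inputs and retain_graph: a schedule such as zero-bubble pipelining can run the activation-gradient pass and the weight-gradient pass over the same graph as two separate calls.

import torch

layer = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)
out = layer(x)
recv_grad = torch.ones_like(out)  # stands in for the gradient received from the next stage

# Pass 1: activation gradients only; retain the graph for the weight pass.
torch.autograd.backward(out, recv_grad, inputs=[x], retain_graph=True)
print(x.grad is not None, layer.weight.grad is None)  # True True

# Pass 2: weight gradients only; the graph may be freed afterwards.
torch.autograd.backward(out, recv_grad, inputs=list(layer.parameters()), retain_graph=False)
print(layer.weight.grad is not None)  # True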
27 changes: 20 additions & 7 deletions tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -19,6 +19,8 @@
from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import assert_loose_close
@@ -751,12 +753,13 @@ def run_with_hybridplugin(test_config):
"config",
[
(0, 1, 4, 1, 1),
# (0, 2, 2, 1, 1),
# (0, 2, 1, 2, 1),
# (0, 2, 1, 1, 2),
(1, 2, 2, 1, 1),
(1, 2, 1, 2, 1),
(1, 2, 1, 1, 2),
],
)
def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
test_config = config
stage, ep_size, pp_size, tp_size, sp_size = config
num_microbatches = pp_size
dist.get_world_size()
Expand Down Expand Up @@ -865,16 +868,23 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
)
# stage 0 chunk 0
parallel_output = None
if rank == dist.get_process_group_ranks(plugin.pp_group)[0]:
if (
booster.plugin.stage_manager.is_first_stage(ignore_chunk=True)
and rank == dist.get_process_group_ranks(plugin.pp_group)[0]
):
parallel_output = sharded_output["loss"]
else:
parallel_output = torch.tensor(12345.0, device="cuda")
# broadcast along pp axis
dist.broadcast(parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[0], group=plugin.pp_group)

else:
# for test without pp
parallel_output = parallel_model(inputs_embeds=input_embeddings.to(dtype)).last_hidden_state.mean()
parallel_optimizer.backward(parallel_output)
parallel_optimizer.step()
parallel_optimizer.zero_grad()
# dist.all_reduce(parallel_output, group=plugin.dp_group)
dist.all_reduce(parallel_output, group=plugin.dp_group)

# ===================================================================================
# run normal model with all dp(different) inputs
@@ -891,8 +901,11 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
p.grad /= dp_size
torch_optimizer.step()
torch_optimizer.zero_grad()
if rank == dist.get_process_group_ranks(plugin.pp_group)[0]:
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
print(f"rank {dist.get_rank()} config {test_config} test passed")
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()


def run_dist(rank, world_size, port):
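The updated test has the first pipeline stage own the loss, broadcasts it along the pp group, and all-reduces it over the dp group before the loose-close comparison. A single-process sketch of that collective pattern (assumptions: gloo backend, world size 1, so both collectives are effectively no-ops):

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

loss = torch.tensor(12345.0)  # placeholder value, as on ranks that do not own the loss
dist.broadcast(loss, src=0)   # share the first stage's loss along the pp group
dist.all_reduce(loss)         # sum across the dp group before comparing to the reference
dist.destroy_process_group()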
