
Commit 1517d71
Merge branch 'master' into reduce_coalesced_fetch_bubble
loadams authored Jan 6, 2025
2 parents 99de70e + 0dbbb70 commit 1517d71
Showing 8 changed files with 77 additions and 14 deletions.
3 changes: 2 additions & 1 deletion deepspeed/runtime/compiler.py
@@ -4,6 +4,7 @@
# DeepSpeed Team

import torch
from deepspeed.utils.torch import required_torch_version

try:
from torch.compiler import is_compiling as torch_is_compiling
@@ -16,7 +17,7 @@


def is_compile_supported():
return hasattr(torch, "compiler") and hasattr(torch.nn.Module, "compile")
return required_torch_version(min_version=2.1)


def disable(func):
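A minimal usage sketch of the new version gate (not part of this commit; the toy model and call site below are illustrative):

# Hypothetical usage sketch: guard torch.compile behind the torch >= 2.1 check.
import torch
from deepspeed.runtime.compiler import is_compile_supported

model = torch.nn.Linear(16, 16)

if is_compile_supported():
    # On torch >= 2.1 the compiler stack is available, so compilation can be requested.
    model = torch.compile(model)
else:
    # Older torch: keep the module in eager mode.
    pass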
14 changes: 12 additions & 2 deletions deepspeed/runtime/pipe/module.py
@@ -116,7 +116,9 @@ def forward(self, inputs):
partition_method (str, optional): The method upon which the layers are partitioned. Defaults to 'parameters'.
activation_checkpoint_interval (int, optional): The granularity activation checkpointing in terms of number of layers. 0 disables activation checkpointing.
activation_checkpoint_func (callable, optional): The function to use for activation checkpointing. Defaults to ``deepspeed.checkpointing.checkpoint``.
checkpointable_layers(list, optional): Checkpointable layers may not be checkpointed. Defaults to None which does not additional filtering.
checkpointable_layers (list[str], optional): List of layer class names that are eligible for checkpointing. For GPT models,
ParallelTransformerLayerPipe is always checkpointed regardless of this list. If None, all layers with parameters are
considered checkpointable. Defaults to None.
dynamic_shape: Allows dynamic shapes of inputs. This might have a performance impact.
"""

@@ -650,9 +652,17 @@ def _is_checkpointable(self, funcs):
# because only non_reentrant_checkpoint can accept inputs with requires_grad=False
# otherwise, the backward of the embedding layer won't receive gradients.
if self.__class__.__name__ in ('GPTModelPipe', 'GPT2ModelPipe'):
return all('ParallelTransformerLayerPipe' in f.__class__.__name__ for f in funcs)
# For GPT models, checkpoint both transformer layers and any additional
# layers specified in checkpointable_layers (if provided)
return all('ParallelTransformerLayerPipe' in f.__class__.__name__ or (
self.checkpointable_layers is not None and f.__class__.__name__ in self.checkpointable_layers)
for f in funcs)

if self.checkpointable_layers is not None:
# For non-GPT models, only checkpoint layers specified in checkpointable_layers
return all(f.__class__.__name__ in self.checkpointable_layers for f in funcs)

# Default behavior: checkpoint any layer that has parameters
params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)]
return any(len(list(p)) > 0 for p in params)

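A minimal sketch of how the documented checkpointable_layers filter can be passed to PipelineModule (not part of this commit; the MyBlock layer, stage layout, and constructor arguments are illustrative, and building a PipelineModule assumes the distributed backend is already initialized, e.g. via deepspeed.init_distributed()):

# Illustrative sketch: restrict activation checkpointing to a named layer class.
import torch
from deepspeed.pipe import PipelineModule, LayerSpec

class MyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(32, 32)

    def forward(self, x):
        return torch.relu(self.proj(x))

layers = [LayerSpec(MyBlock), LayerSpec(torch.nn.Linear, 32, 32)]

# Only MyBlock stages are treated as checkpointable; the plain Linear stage is not.
model = PipelineModule(layers=layers,
                       num_stages=1,
                       activation_checkpoint_interval=1,
                       checkpointable_layers=["MyBlock"])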
7 changes: 2 additions & 5 deletions deepspeed/runtime/zero/stage3.py
@@ -16,6 +16,7 @@
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from deepspeed.runtime.base_optimizer import ZeROOptimizer
from deepspeed.utils import logger
from deepspeed.utils.torch import register_grad_hook
from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce, all_to_all_loco_quant_reduce
from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item
@@ -1159,7 +1160,6 @@ def overlapping_partition_gradients_reduce_epilogue(self):

def create_reduce_and_remove_grad_hooks(self):
print_rank_0(f'[Begin] Create gradient reduction hooks')
self.grad_accs = []
self.leaf_parameters = defaultdict(list)
for i, param_group in enumerate(self.fp16_groups):
for param in param_group:
@@ -1172,15 +1172,12 @@ def create_reduce_and_remove_grad_hooks(self):

#print(f"After all gather {param.device}, {param.shape}")
def wrapper(param):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]

@instrument_w_nvtx
def reduce_partition_and_remove_grads(*notneeded):
self.reduce_ready_partitions_and_remove_grads(param)

self._grad_acc_hooks.append(grad_acc.register_hook(reduce_partition_and_remove_grads))
self.grad_accs.append(grad_acc)
self._grad_acc_hooks.append(register_grad_hook(param, reduce_partition_and_remove_grads))

#print(f"param grad fn {param.expand_as(param).grad_fn}")
if z3_leaf_parameter(param):
9 changes: 9 additions & 0 deletions deepspeed/utils/torch.py
@@ -20,3 +20,12 @@ def required_torch_version(min_version=None, max_version=None):
return False

return True


def register_grad_hook(param, hook):
if required_torch_version(min_version=2.1):
return param.register_post_accumulate_grad_hook(hook)
else:
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
return grad_acc.register_hook(hook)
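A minimal usage sketch of the new register_grad_hook helper (not part of this commit; the toy parameter and callback are illustrative). It registers a callback that fires once a parameter's gradient has been fully accumulated, using register_post_accumulate_grad_hook on torch >= 2.1 and the legacy AccumulateGrad-node hook otherwise:

# Illustrative sketch of register_grad_hook.
import torch
from deepspeed.utils.torch import register_grad_hook

param = torch.nn.Parameter(torch.randn(4, 4))

def on_grad_ready(*_):
    # In ZeRO-3 this is where reduce_ready_partitions_and_remove_grads would run.
    print("gradient accumulated for param of shape", tuple(param.shape))

handle = register_grad_hook(param, on_grad_ready)

loss = (param * 2).sum()
loss.backward()  # triggers on_grad_ready after the gradient is accumulated

handle.remove()  # both hook variants return a removable handle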
2 changes: 0 additions & 2 deletions tests/unit/ops/transformer/inference/test_bias_add.py
@@ -15,8 +15,6 @@
if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
pytest.skip("Inference ops are not available on this system", allow_module_level=True)

torch_minor_version = None


def run_bias_add_reference(activations, bias):
return activations + bias
4 changes: 2 additions & 2 deletions tests/unit/ops/transformer/inference/test_bias_gelu.py
@@ -10,8 +10,8 @@
from deepspeed.ops.op_builder import InferenceBuilder
from deepspeed.ops.transformer import DeepSpeedInferenceConfig
from deepspeed.ops.transformer.inference.op_binding.bias_gelu import BiasGeluOp
from deepspeed.utils.torch import required_torch_version
from .inference_test_utils import allclose, get_dtypes
from packaging import version as pkg_version

if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
pytest.skip("Inference ops are not available on this system", allow_module_level=True)
@@ -34,7 +34,7 @@ def run_bias_gelu_ds(activations, bias):
@pytest.mark.parametrize("channels", [512, 1232, 4096])
@pytest.mark.parametrize("dtype", get_dtypes())
def test_bias_gelu(batch, sequence, channels, dtype):
if pkg_version.parse(torch.__version__) < pkg_version.parse("1.12"):
if not required_torch_version(min_version=1.12):
pytest.skip("gelu implementation matches only after torch 1.12")

activations_ds = torch.randn((batch, sequence, channels), dtype=dtype, device=get_accelerator().device_name())
2 changes: 0 additions & 2 deletions tests/unit/ops/transformer/inference/test_matmul.py
@@ -11,8 +11,6 @@
if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
pytest.skip("Inference ops are not available on this system", allow_module_level=True)

inference_module = None


def allclose(x, y):
assert x.dtype == y.dtype
@@ -8,6 +8,7 @@
import pytest
import torch
import deepspeed
from deepspeed.pipe import PipelineModule, LayerSpec
from deepspeed.accelerator import get_accelerator
from copy import deepcopy
from unit.common import DistributedTest
@@ -259,3 +260,52 @@ def test_ckpt_non_tensor_output_ordering(self, non_tensor_output):
else:
ordering += [torch.is_tensor(non_tensor_output)]
_test_activation_checkpoint_ordering(module, ordering, inputs)


class TestCheckpointableLayersConfig(DistributedTest):
world_size = 1

def test_gpt2_checkpointable_layers(self):
if get_accelerator().device_name() == "cpu":
pytest.skip("CPU accelerator does not support this test yet")

# Create a simple topology for testing
from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology
topo = PipeModelDataParallelTopology(num_pp=1, num_mp=1, num_dp=1)

# Create test classes that we want to checkpoint
class TestTransformerLayer(torch.nn.Module):

def forward(self, x):
return x

class ParallelTransformerLayerPipe(TestTransformerLayer):
pass

class GMLPBlock(TestTransformerLayer):
pass

# Create a mock GPT2 model with different layer types
class TestGPT2ModelPipe(PipelineModule):

def __init__(self):
self.layers_spec = [
LayerSpec(ParallelTransformerLayerPipe),
LayerSpec(GMLPBlock),
LayerSpec(torch.nn.Linear, 10, 10), # Should not be checkpointed
]

super().__init__(layers=self.layers_spec,
topology=topo,
checkpointable_layers=["GMLPBlock", "ParallelTransformerLayerPipe"])

model = TestGPT2ModelPipe()
model.to(get_accelerator().device_name())

# Build layers manually for testing
layers = [spec.build() for spec in model.layers_spec]

# Test that _is_checkpointable returns correct values
assert model._is_checkpointable([layers[0]]) == True # ParallelTransformerLayerPipe
assert model._is_checkpointable([layers[1]]) == True # GMLPBlock
assert model._is_checkpointable([layers[2]]) == False # Linear layer
