diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index c627846b743c..26196ff37ac4 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -5,6 +5,8 @@ from types import MethodType from collections import OrderedDict +from functools import reduce +from operator import mul import torch from deepspeed import comm as dist @@ -40,6 +42,9 @@ PIPE_RECV_INPUT_TIMER = 'pipe_recv_input' PIPE_RECV_GRAD_TIMER = 'pipe_recv_grad' +# The buffer size to store the meta data for each tensor. +TENSOR_META_SIZE = 256 + def is_even(number): return number % 2 == 0 @@ -179,6 +184,7 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs): } self.pipe_recv_buf = None self.grad_layer = None + self._grad_layer_buf = [] self.meta_buffer = None @@ -250,6 +256,8 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs): self.timers(STEP_MICRO_TIMER).start() self.timers(STEP_MICRO_TIMER).stop() + self.dynamic_shape = self.module.dynamic_shape + def set_has_attention_mask(self, value): assert isinstance(value, bool) self.has_attention_mask = value @@ -318,6 +326,7 @@ def reset_activation_shape(self): self.first_output_send = True self.pipe_recv_buf = None self.grad_layer = None + self._grad_layer_buf = [] self.meta_buffer = None self.pipe_partition_input_meta_cache = None @@ -926,51 +935,38 @@ def _send_tensor_meta(self, buffer, recv_stage): * ndims * shape """ - send_bytes = 0 + meta_buffer = torch.empty(TENSOR_META_SIZE, dtype=torch.int32, device=self.device) if isinstance(buffer, torch.Tensor): - type_tensor = torch.LongTensor(data=[0]).to(self.device) - p2p.send(type_tensor, recv_stage) - send_shape = torch.LongTensor(data=buffer.size()).to(self.device) - send_ndims = torch.LongTensor(data=[len(buffer.size())]).to(self.device) - p2p.send(send_ndims, recv_stage) - p2p.send(send_shape, recv_stage) - send_bytes += _tensor_bytes(buffer) - elif isinstance(buffer, list): - assert (False) - type_tensor = torch.LongTensor(data=[1]).to(self.device) - p2p.send(type_tensor, recv_stage) - count_tensor = torch.LongTensor(data=[len(buffer)]).to(self.device) - p2p.send(count_tensor, recv_stage) - for tensor in buffer: - assert isinstance(tensor, torch.Tensor) - send_shape = torch.LongTensor(data=tensor.size()).to(self.device) - send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device) - p2p.send(send_ndims, recv_stage) - p2p.send(send_shape, recv_stage) - send_bytes += _tensor_bytes(tensor) + meta_buf_list = [ + 0, # type of data (0: tensor, 1: list (unused), 2: tuple) + self.DTYPE_TO_ID[buffer.dtype], # dtype + len(buffer.size()) # ndims + ] + meta_buf_list.extend(buffer.size()) + assert len( + meta_buf_list + ) <= TENSOR_META_SIZE, f"Buffer for metadata is too small. Current buffer size: {TENSOR_META_SIZE} but required {len(meta_buf_list)}" + meta_buffer[:len(meta_buf_list)].copy_(torch.tensor(meta_buf_list, dtype=torch.int32)) + p2p.send(meta_buffer, recv_stage) + elif isinstance(buffer, tuple): - type_tensor = torch.LongTensor(data=[2]).to(self.device) - p2p.send(type_tensor, recv_stage) - count_tensor = torch.LongTensor(data=[len(buffer)]).to(self.device) - p2p.send(count_tensor, recv_stage) - for idx, tensor in enumerate(buffer): + meta_buf_list = [ + 2, # type of data (0: tensor, 1: list (unused), 2: tuple) + len(buffer) # num_tensors + ] + + for tensor in buffer: assert isinstance(tensor, torch.Tensor) - send_shape = torch.LongTensor(data=tensor.size()).to(self.device) - send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device) - send_dtype = torch.LongTensor(data=[self.DTYPE_TO_ID[tensor.dtype]]).to(self.device) - p2p.send(send_dtype, recv_stage) - p2p.send(send_ndims, recv_stage) - p2p.send(send_shape, recv_stage) - # Useful for performance debugging. - ''' - new_bytes = _tensor_bytes(tensor) - send_bytes += _tensor_bytes(tensor) - # Useful for performance debugging. - if self.grid.data_parallel_id == 0: - print( - f'STAGE={self.stage_id} pipe-send-volume[{idx}]: shape={send_shape} {new_bytes/1024**2:0.2f}MB' - ) - ''' + meta_buf_list.append(self.DTYPE_TO_ID[tensor.dtype]) + meta_buf_list.append(len(tensor.size())) + meta_buf_list.extend(tensor.size()) + + assert len( + meta_buf_list + ) <= TENSOR_META_SIZE, f"Buffer for metadata is too small. Current buffer size: {TENSOR_META_SIZE} but required {len(meta_buf_list)}" + meta_buffer[:len(meta_buf_list)].copy_(torch.tensor(meta_buf_list, dtype=torch.int32)) + p2p.send(meta_buffer, recv_stage) + else: raise NotImplementedError(f'Could not send meta type {type(buffer)}') @@ -983,49 +979,35 @@ def _send_tensor_meta(self, buffer, recv_stage): def _recv_tensor_meta(self, send_stage): """Receive metadata about upcoming p2p transfers and return allocated buffers. - Metadata is communicated in this order: - * type (0: tensor, 1: list) - * num_tensors if type=list - foreach tensor in buffer: - * ndims - * shape - Returns: Allocated buffer for receiving from send_stage. """ + buffer = torch.empty(TENSOR_META_SIZE, dtype=torch.int32, device=self.device) + p2p.recv(buffer, send_stage) - type_tensor = torch.LongTensor(data=[0]).to(self.device) - p2p.recv(type_tensor, send_stage) - recv_type = type_tensor.item() + recv_type = buffer[0].item() # A single tensor will be sent. if recv_type == 0: - recv_ndims = torch.LongTensor(data=[0]).to(self.device) - p2p.recv(recv_ndims, send_stage) - recv_ndims = recv_ndims.item() - recv_shape = torch.LongTensor([1] * recv_ndims).to(self.device) - p2p.recv(recv_shape, send_stage) - recv_shape = recv_shape.tolist() - return self._allocate_buffer(recv_shape, num_buffers=1)[0] - - # List or tuple of tensors + recv_dtype = self.ID_TO_DTYPE[buffer[1].item()] + recv_ndims = buffer[2].item() + recv_shape = buffer[3:3 + recv_ndims].tolist() + return self._allocate_or_extend_buffers(0, recv_shape, recv_dtype) + + # List or tuple of tensors (recv_type == 1 (list) is currently unused) elif recv_type == 1 or recv_type == 2: - count_tensor = torch.LongTensor(data=[0]).to(self.device) - p2p.recv(count_tensor, send_stage) - num_tensors = count_tensor.item() - recv_shapes_and_dtypes = [] + num_tensors = buffer[1].item() + + buffers = [] + offset = 2 for idx in range(num_tensors): - recv_dtype = torch.LongTensor(data=[0]).to(self.device) - p2p.recv(recv_dtype, send_stage) - recv_dtype = self.ID_TO_DTYPE[recv_dtype.item()] - recv_ndims = torch.LongTensor(data=[0]).to(self.device) - p2p.recv(recv_ndims, send_stage) - recv_ndims = recv_ndims.item() - recv_shape = torch.LongTensor([1] * recv_ndims).to(self.device) - p2p.recv(recv_shape, send_stage) - recv_shapes_and_dtypes.append((recv_shape.tolist(), recv_dtype)) - - buffers = self._allocate_buffers(recv_shapes_and_dtypes, num_buffers=1)[0] + recv_dtype = self.ID_TO_DTYPE[buffer[offset].item()] + recv_ndims = buffer[offset + 1].item() + recv_shape = buffer[offset + 2:offset + 2 + recv_ndims].tolist() + offset += 2 + recv_ndims + + buffers.append(self._allocate_or_extend_buffers(idx, recv_shape, recv_dtype)) + # Convert to tuples if requested. if recv_type == 2: buffers = tuple(buffers) @@ -1048,7 +1030,7 @@ def _exec_send_activations(self, buffer_id): outputs[-1] = outputs[-1].half() outputs = tuple(outputs) - if self.first_output_send: + if self.dynamic_shape or self.first_output_send: self.first_output_send = False self._send_tensor_meta(outputs, self.next_stage) @@ -1133,7 +1115,7 @@ def _exec_recv_activations(self, buffer_id): recvd = None # Allocate the buffer if necessary - if self.pipe_recv_buf is None: + if self.dynamic_shape or self.pipe_recv_buf is None: self.pipe_recv_buf = self._recv_tensor_meta(self.prev_stage) if isinstance(self.pipe_recv_buf, torch.Tensor): @@ -1188,10 +1170,9 @@ def _exec_recv_grads(self, buffer_id): self.pipe_buffers['outputs'][buffer_id] = outputs # Allocate gradient if necessary - if self.grad_layer is None: + if self.dynamic_shape or self.grad_layer is None: if isinstance(outputs, torch.Tensor): - s = list(outputs.size()) - self.grad_layer = self._allocate_buffer(s, dtype=outputs.dtype, num_buffers=1)[0] + self.grad_layer = self._allocate_or_extend_buffers(0, list(outputs.size()), outputs.dtype) else: # XXX This is a HACK # When we exchange activations/gradients, the two pipe stages @@ -1213,7 +1194,11 @@ def _exec_recv_grads(self, buffer_id): for t in outputs[2:] if t.is_floating_point()] else: sizes_and_dtypes = [(list(t.size()), t.dtype) for t in outputs if t.is_floating_point()] - self.grad_layer = self._allocate_buffers(sizes_and_dtypes, num_buffers=1)[0] + + self.grad_layer = [ + self._allocate_or_extend_buffers(i, size, dtype) + for i, (size, dtype) in enumerate(sizes_and_dtypes) + ] if isinstance(self.grad_layer, torch.Tensor): p2p.recv(self.grad_layer, self.next_stage) @@ -1294,16 +1279,17 @@ def _allocate_buffer(self, shape, num_buffers=-1, **kwargs): buffers.append(self._allocate_zeros(shape, **kwargs)) return buffers - def _allocate_buffers(self, shapes_and_dtypes, requires_grad=False, num_buffers=-1): - buffers = [] - if num_buffers == -1: - num_buffers = self.num_pipe_buffers - for count in range(num_buffers): - buffer = [] - for shape, dtype in shapes_and_dtypes: - buffer.append(self._allocate_zeros(shape, dtype=dtype, requires_grad=requires_grad)) - buffers.append(buffer) - return buffers + def _allocate_or_extend_buffers(self, idx, shape, dtype): + numel = reduce(mul, shape) if len(shape) > 0 else 1 + if len(self._grad_layer_buf) <= idx or self._grad_layer_buf[idx].numel() < numel: + new_buf = self._allocate_buffer(shape, dtype=dtype, num_buffers=1)[0] + if len(self._grad_layer_buf) <= idx: + self._grad_layer_buf.append(new_buf) + else: + self._grad_layer_buf[idx] = new_buf + return self._grad_layer_buf[idx] + else: + return self._grad_layer_buf[idx].flatten()[:numel].view(shape) def forward(self, *args, **kwargs): """Disabled for pipeline parallel training. See ``train_batch()``. """ diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py index 3c25cbee66ec..31fec30be788 100644 --- a/deepspeed/runtime/pipe/module.py +++ b/deepspeed/runtime/pipe/module.py @@ -117,6 +117,7 @@ def forward(self, inputs): activation_checkpoint_interval (int, optional): The granularity activation checkpointing in terms of number of layers. 0 disables activation checkpointing. activation_checkpoint_func (callable, optional): The function to use for activation checkpointing. Defaults to ``deepspeed.checkpointing.checkpoint``. checkpointable_layers(list, optional): Checkpointable layers may not be checkpointed. Defaults to None which does not additional filtering. + dynamic_shape: Allows dynamic shapes of inputs. This might have a performance impact. """ def __init__(self, @@ -130,7 +131,8 @@ def __init__(self, partition_method='parameters', activation_checkpoint_interval=0, activation_checkpoint_func=checkpointing.checkpoint, - checkpointable_layers=None): + checkpointable_layers=None, + dynamic_shape=False): super().__init__() @@ -213,6 +215,8 @@ def __init__(self, self.tied_comms = self._index_tied_modules() self._synchronize_tied_weights() + self.dynamic_shape = dynamic_shape + def _precompute_checkpointable_values(self): if self.activation_checkpoint_interval > 0 and self.is_checkpointable_results_interval != self.activation_checkpoint_interval: num_layers = len(self.forward_funcs) diff --git a/tests/unit/alexnet_model.py b/tests/unit/alexnet_model.py index 25256d376eeb..dfab28aa7477 100644 --- a/tests/unit/alexnet_model.py +++ b/tests/unit/alexnet_model.py @@ -14,6 +14,7 @@ from deepspeed.utils.torch import required_torch_version from deepspeed.accelerator import get_accelerator from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec +from .util import no_child_process_in_deepspeed_io class AlexNet(nn.Module): @@ -125,22 +126,11 @@ def train_cifar(model, config, num_steps=400, average_dp_losses=True, fp16=True, trainset = cifar_trainset(fp16=fp16) config['local_rank'] = dist.get_rank() - # deepspeed_io defaults to creating a dataloader that uses a - # multiprocessing pool. Our tests use pools and we cannot nest pools in - # python. Therefore we're injecting this kwarg to ensure that no pools - # are used in the dataloader. - old_method = deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io - - def new_method(*args, **kwargs): - kwargs["num_local_io_workers"] = 0 - return old_method(*args, **kwargs) - - deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io = new_method - - engine, _, _, _ = deepspeed.initialize(config=config, - model=model, - model_parameters=[p for p in model.parameters()], - training_data=trainset) + with no_child_process_in_deepspeed_io(): + engine, _, _, _ = deepspeed.initialize(config=config, + model=model, + model_parameters=[p for p in model.parameters()], + training_data=trainset) losses = [] for step in range(num_steps): diff --git a/tests/unit/runtime/pipe/test_pipe.py b/tests/unit/runtime/pipe/test_pipe.py index 88e26290b650..f198762c5fcc 100644 --- a/tests/unit/runtime/pipe/test_pipe.py +++ b/tests/unit/runtime/pipe/test_pipe.py @@ -7,12 +7,15 @@ import torch.nn as nn import pytest +import torch + +import deepspeed import deepspeed.comm as dist from deepspeed.runtime.pipe.topology import PipeDataParallelTopology from deepspeed.runtime.pipe.module import PipelineModule from unit.alexnet_model import AlexNetPipe, train_cifar from unit.common import DistributedTest -from unit.util import skip_on_arch +from unit.util import skip_on_arch, no_child_process_in_deepspeed_io PipeTopo = PipeDataParallelTopology @@ -155,3 +158,95 @@ def test_pipe_use_reentrant(self, topo_config): # the following check could passed on higher version docker: nvcr.io/nvidia/pytorch:23.07-py3(torch2.1.0 cuda12.1) # Check if models have same weights after training # self._check_model_params_equal(base_model, test_model) + + +class DynamicShapeTestLayer(nn.Module): + + def __init__(self, hidden_size): + super().__init__() + self.fc = nn.Linear(hidden_size, hidden_size) + self.shapes = set() + + def forward(self, x): + self.shapes.add(x.shape) + y = self.fc(x) + return y + + +class DynamicShapeTestModel(nn.Module): + + def __init__(self, n_layers, hidden_size): + super().__init__() + self.layers = nn.ModuleList([DynamicShapeTestLayer(hidden_size) for _ in range(n_layers)]) + + +@pytest.mark.parametrize('topo_config', [ + { + "num_pp": 1, + "num_dp": 4 + }, + { + "num_pp": 2, + "num_dp": 2 + }, + { + "num_pp": 4, + "num_dp": 1 + }, +]) +class TestPipeDynamicShape(DistributedTest): + world_size = 4 + + def test_pipe_base(self, topo_config): + """This test checks if the pipeline engine can handle dynamic shapes correctly. + We pass inputs of different shapes to the pipeline engine. + """ + + n_iter = 10 + n_layers = 4 + n_samples = 1024 + batch_size = 4 + channel_dims = [8, 16, 32, 64] + hidden_size = 16 + + topo = PipeTopo(**topo_config) + + model = DynamicShapeTestModel(n_layers, hidden_size) + model = PipelineModule(layers=model.layers, topology=topo, loss_fn=nn.MSELoss(), dynamic_shape=True) + + # Each batch has different channel dim but we use the same channel dim in the same batch + xs = [ + torch.randn(channel_dims[(i // batch_size) % len(channel_dims)], hidden_size, dtype=torch.float32) + for i in range(n_samples) + ] + ys = [torch.randn_like(x) for x in xs] + + class CustomDataset(torch.utils.data.Dataset): + + def __init__(self, xs, ys): + self.xs = xs + self.ys = ys + + def __len__(self): + return len(self.xs) + + def __getitem__(self, idx): + return self.xs[idx], self.ys[idx] + + dataset = CustomDataset(xs, ys) + + config_dict["train_batch_size"] = batch_size + + with no_child_process_in_deepspeed_io(): + engine, _, _, _ = deepspeed.initialize(config=config_dict, + model=model, + model_parameters=[p for p in model.parameters()], + training_data=dataset) + + for _ in range(n_iter): + _ = engine.train_batch() + + # Check if all layers have seen different shapes + for layer in model.modules(): + if isinstance(layer, DynamicShapeTestLayer): + assert len(layer.shapes) > 1 diff --git a/tests/unit/util.py b/tests/unit/util.py index feec326ede6c..dba29ed27a4c 100644 --- a/tests/unit/util.py +++ b/tests/unit/util.py @@ -5,6 +5,8 @@ import pytest import torch + +import deepspeed from deepspeed.accelerator import get_accelerator, is_current_accelerator_supported from deepspeed.git_version_info import torch_info @@ -67,3 +69,22 @@ def required_amp_check(): return False else: return True + + +class no_child_process_in_deepspeed_io: + + def __enter__(self): + # deepspeed_io defaults to creating a dataloader that uses a + # multiprocessing pool. Our tests use pools and we cannot nest pools in + # python. Therefore we're injecting this kwarg to ensure that no pools + # are used in the dataloader. + self.old_method = deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io + + def new_method(*args, **kwargs): + kwargs["num_local_io_workers"] = 0 + return self.old_method(*args, **kwargs) + + deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io = new_method + + def __exit__(self, *_): + deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io = self.old_method