Enable dynamic shapes for pipeline parallel engine inputs (#5481)
This PR enables dynamic shapes for inputs to the pipeline parallel (PP)
engine.

Currently, the PP engine checks tensor shapes and allocates communication
buffers only during the first forward/backward passes. This causes a tensor
shape mismatch error when the input tensor shapes change later.
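
The failure mode is easiest to see in isolation. The snippet below is a conceptual stand-in, not the engine's actual code path: it only shows that a receive buffer sized from the first micro-batch cannot accept a later micro-batch with a different shape.

```python
# Conceptual sketch only: a fixed-size buffer allocated from the first
# micro-batch rejects a later micro-batch whose shape has changed.
import torch

first_batch = torch.randn(4, 128, 768)
recv_buf = torch.empty_like(first_batch)  # allocated once, at the first pass

later_batch = torch.randn(4, 256, 768)    # e.g. the sequence length changed
try:
    recv_buf.copy_(later_batch)           # stands in for the p2p recv into the fixed buffer
except RuntimeError as e:
    print(f"shape mismatch: {e}")
```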

This PR adds an option to check tensor shapes at every iteration and
allocate buffers based on those shapes. As shown below, you can enable this
feature by passing `dynamic_shape=True` to `PipelineModule`.
Note that this might have a performance impact, so the option defaults to
`False`.

```python
model = PipelineModule(
    ...
    dynamic_shape=True
)
```
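
For a fuller, hypothetical picture, the sketch below constructs a two-stage pipeline with dynamic shapes enabled. The toy layers, sizes, and stage count are assumptions for illustration; the import path is the one used elsewhere in this PR, and the script is assumed to run under the `deepspeed` launcher with distributed communication initialized. Only `dynamic_shape=True` is new in this PR.

```python
# Sketch under stated assumptions: toy layers and a two-stage split.
import torch.nn as nn
import deepspeed
from deepspeed.runtime.pipe.module import PipelineModule

deepspeed.init_distributed()  # assumes launch via the deepspeed launcher

layers = [nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 10)]

model = PipelineModule(
    layers=layers,
    num_stages=2,
    dynamic_shape=True,  # exchange tensor metadata every iteration; grow buffers as needed
)
```

With the flag set, micro-batches whose shapes vary across iterations (for example, different sequence lengths) no longer trigger the shape-mismatch error described above.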

Enabling this option increases the overhead of buffer allocation and of
communicating tensor metadata. To mitigate that overhead, this PR also
includes the following improvements (a simplified sketch follows the list):
- Consolidate multiple communication calls to send/recv tensor shapes into a single message 9f96ad4
- Reuse (extend) the communication buffer instead of creating a new one b3c0750
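
To illustrate the first improvement, here is a minimal, self-contained sketch of the packing scheme: all dtype/ndims/shape information for a tuple of tensors is flattened into a single fixed-size int32 buffer, so one send/recv replaces the many small ones used before. The `DTYPE_TO_ID` table here is an illustrative stand-in for the engine's own mapping, and plain functions replace the actual `p2p.send`/`p2p.recv` calls.

```python
# Standalone sketch of the metadata packing used by _send_tensor_meta /
# _recv_tensor_meta after this PR. DTYPE_TO_ID is an illustrative stand-in.
import torch

TENSOR_META_SIZE = 256  # fixed-size metadata buffer, as in the PR

DTYPE_TO_ID = {torch.float32: 0, torch.float16: 1, torch.bfloat16: 2}
ID_TO_DTYPE = {v: k for k, v in DTYPE_TO_ID.items()}


def pack_meta(tensors):
    """Flatten type/count plus per-tensor (dtype, ndims, shape) into one int32 buffer."""
    meta = [2, len(tensors)]  # 2 == tuple of tensors, followed by the tensor count
    for t in tensors:
        meta.append(DTYPE_TO_ID[t.dtype])
        meta.append(t.dim())
        meta.extend(t.size())
    assert len(meta) <= TENSOR_META_SIZE, "metadata buffer too small"
    buf = torch.zeros(TENSOR_META_SIZE, dtype=torch.int32)
    buf[:len(meta)] = torch.tensor(meta, dtype=torch.int32)
    return buf  # the engine sends this with a single p2p.send


def unpack_meta(buf):
    """Recover the (shape, dtype) of every tensor from the packed buffer."""
    assert buf[0].item() == 2, "this sketch only handles the tuple case"
    num_tensors = buf[1].item()
    shapes_and_dtypes, offset = [], 2
    for _ in range(num_tensors):
        dtype = ID_TO_DTYPE[buf[offset].item()]
        ndims = buf[offset + 1].item()
        shape = buf[offset + 2:offset + 2 + ndims].tolist()
        shapes_and_dtypes.append((shape, dtype))
        offset += 2 + ndims
    return shapes_and_dtypes


if __name__ == "__main__":
    activations = (torch.randn(4, 128, 768), torch.randn(4, 128, dtype=torch.float16))
    print(unpack_meta(pack_meta(activations)))
    # [([4, 128, 768], torch.float32), ([4, 128], torch.float16)]
```

The second improvement shows up in the new `_allocate_or_extend_buffers` helper: the engine keeps one buffer per tensor slot, grows it only when an incoming tensor needs more elements, and otherwise returns a view of the existing buffer with the required shape.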

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
tohtana and tjruwase authored Aug 16, 2024
1 parent 4d4ff0e commit 1ab1928
Showing 5 changed files with 206 additions and 110 deletions.
170 changes: 78 additions & 92 deletions deepspeed/runtime/pipe/engine.py
@@ -5,6 +5,8 @@

from types import MethodType
from collections import OrderedDict
from functools import reduce
from operator import mul

import torch
from deepspeed import comm as dist
@@ -40,6 +42,9 @@
PIPE_RECV_INPUT_TIMER = 'pipe_recv_input'
PIPE_RECV_GRAD_TIMER = 'pipe_recv_grad'

# The buffer size to store the meta data for each tensor.
TENSOR_META_SIZE = 256


def is_even(number):
return number % 2 == 0
@@ -179,6 +184,7 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):
}
self.pipe_recv_buf = None
self.grad_layer = None
self._grad_layer_buf = []

self.meta_buffer = None

@@ -250,6 +256,8 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):
self.timers(STEP_MICRO_TIMER).start()
self.timers(STEP_MICRO_TIMER).stop()

self.dynamic_shape = self.module.dynamic_shape

def set_has_attention_mask(self, value):
assert isinstance(value, bool)
self.has_attention_mask = value
@@ -318,6 +326,7 @@ def reset_activation_shape(self):
self.first_output_send = True
self.pipe_recv_buf = None
self.grad_layer = None
self._grad_layer_buf = []
self.meta_buffer = None

self.pipe_partition_input_meta_cache = None
@@ -926,51 +935,38 @@ def _send_tensor_meta(self, buffer, recv_stage):
* ndims
* shape
"""
send_bytes = 0
meta_buffer = torch.empty(TENSOR_META_SIZE, dtype=torch.int32, device=self.device)
if isinstance(buffer, torch.Tensor):
type_tensor = torch.LongTensor(data=[0]).to(self.device)
p2p.send(type_tensor, recv_stage)
send_shape = torch.LongTensor(data=buffer.size()).to(self.device)
send_ndims = torch.LongTensor(data=[len(buffer.size())]).to(self.device)
p2p.send(send_ndims, recv_stage)
p2p.send(send_shape, recv_stage)
send_bytes += _tensor_bytes(buffer)
elif isinstance(buffer, list):
assert (False)
type_tensor = torch.LongTensor(data=[1]).to(self.device)
p2p.send(type_tensor, recv_stage)
count_tensor = torch.LongTensor(data=[len(buffer)]).to(self.device)
p2p.send(count_tensor, recv_stage)
for tensor in buffer:
assert isinstance(tensor, torch.Tensor)
send_shape = torch.LongTensor(data=tensor.size()).to(self.device)
send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device)
p2p.send(send_ndims, recv_stage)
p2p.send(send_shape, recv_stage)
send_bytes += _tensor_bytes(tensor)
meta_buf_list = [
0, # type of data (0: tensor, 1: list (unused), 2: tuple)
self.DTYPE_TO_ID[buffer.dtype], # dtype
len(buffer.size()) # ndims
]
meta_buf_list.extend(buffer.size())
assert len(
meta_buf_list
) <= TENSOR_META_SIZE, f"Buffer for metadata is too small. Current buffer size: {TENSOR_META_SIZE} but required {len(meta_buf_list)}"
meta_buffer[:len(meta_buf_list)].copy_(torch.tensor(meta_buf_list, dtype=torch.int32))
p2p.send(meta_buffer, recv_stage)

elif isinstance(buffer, tuple):
type_tensor = torch.LongTensor(data=[2]).to(self.device)
p2p.send(type_tensor, recv_stage)
count_tensor = torch.LongTensor(data=[len(buffer)]).to(self.device)
p2p.send(count_tensor, recv_stage)
for idx, tensor in enumerate(buffer):
meta_buf_list = [
2, # type of data (0: tensor, 1: list (unused), 2: tuple)
len(buffer) # num_tensors
]

for tensor in buffer:
assert isinstance(tensor, torch.Tensor)
send_shape = torch.LongTensor(data=tensor.size()).to(self.device)
send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device)
send_dtype = torch.LongTensor(data=[self.DTYPE_TO_ID[tensor.dtype]]).to(self.device)
p2p.send(send_dtype, recv_stage)
p2p.send(send_ndims, recv_stage)
p2p.send(send_shape, recv_stage)
# Useful for performance debugging.
'''
new_bytes = _tensor_bytes(tensor)
send_bytes += _tensor_bytes(tensor)
# Useful for performance debugging.
if self.grid.data_parallel_id == 0:
print(
f'STAGE={self.stage_id} pipe-send-volume[{idx}]: shape={send_shape} {new_bytes/1024**2:0.2f}MB'
)
'''
meta_buf_list.append(self.DTYPE_TO_ID[tensor.dtype])
meta_buf_list.append(len(tensor.size()))
meta_buf_list.extend(tensor.size())

assert len(
meta_buf_list
) <= TENSOR_META_SIZE, f"Buffer for metadata is too small. Current buffer size: {TENSOR_META_SIZE} but required {len(meta_buf_list)}"
meta_buffer[:len(meta_buf_list)].copy_(torch.tensor(meta_buf_list, dtype=torch.int32))
p2p.send(meta_buffer, recv_stage)

else:
raise NotImplementedError(f'Could not send meta type {type(buffer)}')

@@ -983,49 +979,35 @@ def _send_tensor_meta(self, buffer, recv_stage):
def _recv_tensor_meta(self, send_stage):
"""Receive metadata about upcoming p2p transfers and return allocated buffers.
Metadata is communicated in this order:
* type (0: tensor, 1: list)
* num_tensors if type=list
foreach tensor in buffer:
* ndims
* shape
Returns:
Allocated buffer for receiving from send_stage.
"""
buffer = torch.empty(TENSOR_META_SIZE, dtype=torch.int32, device=self.device)
p2p.recv(buffer, send_stage)

type_tensor = torch.LongTensor(data=[0]).to(self.device)
p2p.recv(type_tensor, send_stage)
recv_type = type_tensor.item()
recv_type = buffer[0].item()

# A single tensor will be sent.
if recv_type == 0:
recv_ndims = torch.LongTensor(data=[0]).to(self.device)
p2p.recv(recv_ndims, send_stage)
recv_ndims = recv_ndims.item()
recv_shape = torch.LongTensor([1] * recv_ndims).to(self.device)
p2p.recv(recv_shape, send_stage)
recv_shape = recv_shape.tolist()
return self._allocate_buffer(recv_shape, num_buffers=1)[0]

# List or tuple of tensors
recv_dtype = self.ID_TO_DTYPE[buffer[1].item()]
recv_ndims = buffer[2].item()
recv_shape = buffer[3:3 + recv_ndims].tolist()
return self._allocate_or_extend_buffers(0, recv_shape, recv_dtype)

# List or tuple of tensors (recv_type == 1 (list) is currently unused)
elif recv_type == 1 or recv_type == 2:
count_tensor = torch.LongTensor(data=[0]).to(self.device)
p2p.recv(count_tensor, send_stage)
num_tensors = count_tensor.item()
recv_shapes_and_dtypes = []
num_tensors = buffer[1].item()

buffers = []
offset = 2
for idx in range(num_tensors):
recv_dtype = torch.LongTensor(data=[0]).to(self.device)
p2p.recv(recv_dtype, send_stage)
recv_dtype = self.ID_TO_DTYPE[recv_dtype.item()]
recv_ndims = torch.LongTensor(data=[0]).to(self.device)
p2p.recv(recv_ndims, send_stage)
recv_ndims = recv_ndims.item()
recv_shape = torch.LongTensor([1] * recv_ndims).to(self.device)
p2p.recv(recv_shape, send_stage)
recv_shapes_and_dtypes.append((recv_shape.tolist(), recv_dtype))

buffers = self._allocate_buffers(recv_shapes_and_dtypes, num_buffers=1)[0]
recv_dtype = self.ID_TO_DTYPE[buffer[offset].item()]
recv_ndims = buffer[offset + 1].item()
recv_shape = buffer[offset + 2:offset + 2 + recv_ndims].tolist()
offset += 2 + recv_ndims

buffers.append(self._allocate_or_extend_buffers(idx, recv_shape, recv_dtype))

# Convert to tuples if requested.
if recv_type == 2:
buffers = tuple(buffers)
Expand All @@ -1048,7 +1030,7 @@ def _exec_send_activations(self, buffer_id):
outputs[-1] = outputs[-1].half()
outputs = tuple(outputs)

if self.first_output_send:
if self.dynamic_shape or self.first_output_send:
self.first_output_send = False
self._send_tensor_meta(outputs, self.next_stage)

@@ -1133,7 +1115,7 @@ def _exec_recv_activations(self, buffer_id):
recvd = None

# Allocate the buffer if necessary
if self.pipe_recv_buf is None:
if self.dynamic_shape or self.pipe_recv_buf is None:
self.pipe_recv_buf = self._recv_tensor_meta(self.prev_stage)

if isinstance(self.pipe_recv_buf, torch.Tensor):
@@ -1188,10 +1170,9 @@ def _exec_recv_grads(self, buffer_id):
self.pipe_buffers['outputs'][buffer_id] = outputs

# Allocate gradient if necessary
if self.grad_layer is None:
if self.dynamic_shape or self.grad_layer is None:
if isinstance(outputs, torch.Tensor):
s = list(outputs.size())
self.grad_layer = self._allocate_buffer(s, dtype=outputs.dtype, num_buffers=1)[0]
self.grad_layer = self._allocate_or_extend_buffers(0, list(outputs.size()), outputs.dtype)
else:
# XXX This is a HACK
# When we exchange activations/gradients, the two pipe stages
@@ -1213,7 +1194,11 @@ def _exec_recv_grads(self, buffer_id):
for t in outputs[2:] if t.is_floating_point()]
else:
sizes_and_dtypes = [(list(t.size()), t.dtype) for t in outputs if t.is_floating_point()]
self.grad_layer = self._allocate_buffers(sizes_and_dtypes, num_buffers=1)[0]

self.grad_layer = [
self._allocate_or_extend_buffers(i, size, dtype)
for i, (size, dtype) in enumerate(sizes_and_dtypes)
]

if isinstance(self.grad_layer, torch.Tensor):
p2p.recv(self.grad_layer, self.next_stage)
@@ -1294,16 +1279,17 @@ def _allocate_buffer(self, shape, num_buffers=-1, **kwargs):
buffers.append(self._allocate_zeros(shape, **kwargs))
return buffers

def _allocate_buffers(self, shapes_and_dtypes, requires_grad=False, num_buffers=-1):
buffers = []
if num_buffers == -1:
num_buffers = self.num_pipe_buffers
for count in range(num_buffers):
buffer = []
for shape, dtype in shapes_and_dtypes:
buffer.append(self._allocate_zeros(shape, dtype=dtype, requires_grad=requires_grad))
buffers.append(buffer)
return buffers
def _allocate_or_extend_buffers(self, idx, shape, dtype):
numel = reduce(mul, shape) if len(shape) > 0 else 1
if len(self._grad_layer_buf) <= idx or self._grad_layer_buf[idx].numel() < numel:
new_buf = self._allocate_buffer(shape, dtype=dtype, num_buffers=1)[0]
if len(self._grad_layer_buf) <= idx:
self._grad_layer_buf.append(new_buf)
else:
self._grad_layer_buf[idx] = new_buf
return self._grad_layer_buf[idx]
else:
return self._grad_layer_buf[idx].flatten()[:numel].view(shape)

def forward(self, *args, **kwargs):
"""Disabled for pipeline parallel training. See ``train_batch()``. """
6 changes: 5 additions & 1 deletion deepspeed/runtime/pipe/module.py
@@ -117,6 +117,7 @@ def forward(self, inputs):
activation_checkpoint_interval (int, optional): The granularity activation checkpointing in terms of number of layers. 0 disables activation checkpointing.
activation_checkpoint_func (callable, optional): The function to use for activation checkpointing. Defaults to ``deepspeed.checkpointing.checkpoint``.
checkpointable_layers (list, optional): Checkpointable layers may not be checkpointed. Defaults to None, which does no additional filtering.
dynamic_shape: Allows dynamic shapes of inputs. This might have a performance impact.
"""

def __init__(self,
@@ -130,7 +131,8 @@ def __init__(self,
partition_method='parameters',
activation_checkpoint_interval=0,
activation_checkpoint_func=checkpointing.checkpoint,
checkpointable_layers=None):
checkpointable_layers=None,
dynamic_shape=False):

super().__init__()

@@ -213,6 +215,8 @@ def __init__(self,
self.tied_comms = self._index_tied_modules()
self._synchronize_tied_weights()

self.dynamic_shape = dynamic_shape

def _precompute_checkpointable_values(self):
if self.activation_checkpoint_interval > 0 and self.is_checkpointable_results_interval != self.activation_checkpoint_interval:
num_layers = len(self.forward_funcs)
22 changes: 6 additions & 16 deletions tests/unit/alexnet_model.py
@@ -14,6 +14,7 @@
from deepspeed.utils.torch import required_torch_version
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec
from .util import no_child_process_in_deepspeed_io


class AlexNet(nn.Module):
@@ -125,22 +126,11 @@ def train_cifar(model, config, num_steps=400, average_dp_losses=True, fp16=True,
trainset = cifar_trainset(fp16=fp16)
config['local_rank'] = dist.get_rank()

# deepspeed_io defaults to creating a dataloader that uses a
# multiprocessing pool. Our tests use pools and we cannot nest pools in
# python. Therefore we're injecting this kwarg to ensure that no pools
# are used in the dataloader.
old_method = deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io

def new_method(*args, **kwargs):
kwargs["num_local_io_workers"] = 0
return old_method(*args, **kwargs)

deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io = new_method

engine, _, _, _ = deepspeed.initialize(config=config,
model=model,
model_parameters=[p for p in model.parameters()],
training_data=trainset)
with no_child_process_in_deepspeed_io():
engine, _, _, _ = deepspeed.initialize(config=config,
model=model,
model_parameters=[p for p in model.parameters()],
training_data=trainset)

losses = []
for step in range(num_steps):
(Diffs for the remaining two changed files are not shown.)
