
Commit beea502
Merge branch 'master' into bf16-optimizer-immediate-grad-update
tjruwase authored Dec 18, 2023
2 parents 1272438 + 4559dad commit beea502
Showing 1 changed file with 21 additions and 4 deletions.
deepspeed/runtime/pipe/engine.py: 21 additions & 4 deletions
@@ -182,6 +182,10 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):
 
         self.first_output_send = True
         self.first_gradient_send = True
+        self.pipe_partition_input_meta_cache = None
+        self.pipe_partition_output_meta_cache = None
+        self.pipe_partition_grad_meta_cache = None
+        self.grad_partition_grad_layer_meta_cache = None
 
         #stores the loss for the current micro batch being processed
         self.loss = torch.tensor(0.0).to(self.device)
@@ -309,6 +313,11 @@ def reset_activation_shape(self):
         self.grad_layer = None
         self.meta_buffer = None
 
+        self.pipe_partition_input_meta_cache = None
+        self.pipe_partition_output_meta_cache = None
+        self.pipe_partition_grad_meta_cache = None
+        self.grad_partition_grad_layer_meta_cache = None
+
     def train_batch(self, data_iter=None):
         """Progress the pipeline to train the next batch of data. The engine will ingest
         ``self.train_batch_size()`` total samples collectively across all workers.
@@ -641,7 +650,9 @@ def _exec_forward_pass(self, buffer_id):
 
         # collect the partitioned input from the previous stage
         if self.is_pipe_partitioned and not self.is_first_stage():
-            part_input = PartitionedTensor.from_meta(meta=inputs[0],
+            if self.pipe_partition_input_meta_cache is None:
+                self.pipe_partition_input_meta_cache = inputs[0].to('cpu')
+            part_input = PartitionedTensor.from_meta(meta=self.pipe_partition_input_meta_cache,
                                                      local_part=inputs[1],
                                                      group=self.grid.get_slice_parallel_group())
 
@@ -732,7 +743,9 @@ def _exec_backward_pass(self, buffer_id):
         # careful to also restore the computational graph of the tensors we partitioned.
         if self.is_pipe_partitioned:
             if self.is_grad_partitioned:
-                part_output = PartitionedTensor.from_meta(meta=outputs[0],
+                if self.pipe_partition_output_meta_cache is None:
+                    self.pipe_partition_output_meta_cache = outputs[0].to('cpu')
+                part_output = PartitionedTensor.from_meta(meta=self.pipe_partition_output_meta_cache,
                                                           local_part=outputs[1],
                                                           group=self.grid.get_slice_parallel_group())
                 self.pipe_buffers['output_tensors'][buffer_id].data = part_output.full()
Expand All @@ -745,7 +758,9 @@ def _exec_backward_pass(self, buffer_id):
grad_tensors = self.grad_layer
if self.is_grad_partitioned:
#print(f'RANK={self.global_rank} BEFORE-BWD restoring grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}')
part_grad = PartitionedTensor.from_meta(meta=self.grad_layer[0],
if self.grad_partition_grad_layer_meta_cache is None:
self.grad_partition_grad_layer_meta_cache = self.grad_layer[0].to('cpu')
part_grad = PartitionedTensor.from_meta(meta=self.grad_partition_grad_layer_meta_cache,
local_part=self.grad_layer[1],
group=self.grid.get_slice_parallel_group())
grad_tensors = (part_grad.full(), *grad_tensors[2:])
@@ -1089,7 +1104,9 @@ def _exec_recv_grads(self, buffer_id):
         # XXX these shapes are hardcoded for Megatron
         # Restore partitioned output if it was partitioned and we are sending full gradients
         if self.is_pipe_partitioned and not self.is_grad_partitioned:
-            part_output = PartitionedTensor.from_meta(meta=outputs[0],
+            if self.pipe_partition_grad_meta_cache is None:
+                self.pipe_partition_grad_meta_cache = outputs[0].to('cpu')
+            part_output = PartitionedTensor.from_meta(meta=self.pipe_partition_grad_meta_cache,
                                                       local_part=outputs[1],
                                                       group=self.grid.get_slice_parallel_group())
             outputs[0].data = part_output.full()
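
The pattern is the same in every hunk: PartitionedTensor.from_meta only needs the small, shape-describing meta tensor, but that tensor arrives on the GPU, so decoding its fields forces device-to-host reads on every microbatch. Because the shapes stay fixed until reset_activation_shape() runs, the change copies the meta tensor to the CPU once and reuses it, then nulls all four caches whenever activation shapes may change. Below is a minimal, self-contained sketch of the same caching idea; ShapeCachingConsumer, its single-field meta layout, and the consume()/reset() names are illustrative stand-ins, not the DeepSpeed API.

import torch

class ShapeCachingConsumer:
    """Illustrative stand-in for an object that repeatedly decodes a meta tensor."""

    def __init__(self):
        self._meta_cache = None  # CPU copy of the meta tensor, filled lazily

    def consume(self, meta: torch.Tensor, payload: torch.Tensor) -> torch.Tensor:
        # Copy the small meta tensor off the device only once; afterwards,
        # every .item() read is a cheap host-side access instead of a
        # GPU->CPU synchronization repeated on every microbatch.
        if self._meta_cache is None:
            self._meta_cache = meta.to('cpu')
        numel = int(self._meta_cache[0].item())  # hypothetical field: payload length
        return payload[:numel]

    def reset(self):
        # Mirror of reset_activation_shape(): once shapes can change,
        # the cached CPU copy is stale and must be dropped.
        self._meta_cache = None

# Usage: meta keeps living on the accelerator; only the first call pays the copy.
consumer = ShapeCachingConsumer()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
meta = torch.tensor([3], device=device)
payload = torch.arange(5, device=device)
print(consumer.consume(meta, payload))  # first 3 elements of payload

This also explains why reset_activation_shape() clears all four caches in the commit: a cached CPU copy is only valid while the pipeline's activation shapes remain constant.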
