diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py
index 799a9345b8efa..4aab63827a343 100644
--- a/deepspeed/runtime/pipe/engine.py
+++ b/deepspeed/runtime/pipe/engine.py
@@ -182,6 +182,10 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):
 
         self.first_output_send = True
         self.first_gradient_send = True
+        self.pipe_partition_input_meta_cache = None
+        self.pipe_partition_output_meta_cache = None
+        self.pipe_partition_grad_meta_cache = None
+        self.grad_partition_grad_layer_meta_cache = None
 
         #stores the loss for the current micro batch being processed
         self.loss = torch.tensor(0.0).to(self.device)
@@ -309,6 +313,11 @@ def reset_activation_shape(self):
         self.grad_layer = None
         self.meta_buffer = None
 
+        self.pipe_partition_input_meta_cache = None
+        self.pipe_partition_output_meta_cache = None
+        self.pipe_partition_grad_meta_cache = None
+        self.grad_partition_grad_layer_meta_cache = None
+
     def train_batch(self, data_iter=None):
         """Progress the pipeline to train the next batch of data. The engine will ingest
         ``self.train_batch_size()`` total samples collectively across all workers.
@@ -641,7 +650,9 @@ def _exec_forward_pass(self, buffer_id):
 
         # collect the partitioned input from the previous stage
         if self.is_pipe_partitioned and not self.is_first_stage():
-            part_input = PartitionedTensor.from_meta(meta=inputs[0],
+            if self.pipe_partition_input_meta_cache is None:
+                self.pipe_partition_input_meta_cache = inputs[0].to('cpu')
+            part_input = PartitionedTensor.from_meta(meta=self.pipe_partition_input_meta_cache,
                                                      local_part=inputs[1],
                                                      group=self.grid.get_slice_parallel_group())
 
@@ -732,7 +743,9 @@ def _exec_backward_pass(self, buffer_id):
         # careful to also restore the computational graph of the tensors we partitioned.
         if self.is_pipe_partitioned:
             if self.is_grad_partitioned:
-                part_output = PartitionedTensor.from_meta(meta=outputs[0],
+                if self.pipe_partition_output_meta_cache is None:
+                    self.pipe_partition_output_meta_cache = outputs[0].to('cpu')
+                part_output = PartitionedTensor.from_meta(meta=self.pipe_partition_output_meta_cache,
                                                           local_part=outputs[1],
                                                           group=self.grid.get_slice_parallel_group())
                 self.pipe_buffers['output_tensors'][buffer_id].data = part_output.full()
@@ -745,7 +758,9 @@ def _exec_backward_pass(self, buffer_id):
         grad_tensors = self.grad_layer
         if self.is_grad_partitioned:
             #print(f'RANK={self.global_rank} BEFORE-BWD restoring grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}')
-            part_grad = PartitionedTensor.from_meta(meta=self.grad_layer[0],
+            if self.grad_partition_grad_layer_meta_cache is None:
+                self.grad_partition_grad_layer_meta_cache = self.grad_layer[0].to('cpu')
+            part_grad = PartitionedTensor.from_meta(meta=self.grad_partition_grad_layer_meta_cache,
                                                     local_part=self.grad_layer[1],
                                                     group=self.grid.get_slice_parallel_group())
             grad_tensors = (part_grad.full(), *grad_tensors[2:])
@@ -1089,7 +1104,9 @@ def _exec_recv_grads(self, buffer_id):
         # XXX these shapes are hardcoded for Megatron
         # Restore partitioned output if it was partitioned and we are sending full gradients
         if self.is_pipe_partitioned and not self.is_grad_partitioned:
-            part_output = PartitionedTensor.from_meta(meta=outputs[0],
+            if self.pipe_partition_grad_meta_cache is None:
+                self.pipe_partition_grad_meta_cache = outputs[0].to('cpu')
+            part_output = PartitionedTensor.from_meta(meta=self.pipe_partition_grad_meta_cache,
                                                       local_part=outputs[1],
                                                       group=self.grid.get_slice_parallel_group())
             outputs[0].data = part_output.full()
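
Each hunk applies the same pattern: the `meta` tensor handed to `PartitionedTensor.from_meta` appears to be constant across micro batches until activation shapes change, so it is copied to the CPU once, the cached copy is reused afterwards, and the caches are cleared in `reset_activation_shape()`. Below is a minimal sketch of that idea; the `MetaCache` helper is hypothetical and not part of the DeepSpeed API.

```python
import torch


class MetaCache:
    """Hypothetical helper illustrating the caching pattern in the diff above."""

    def __init__(self):
        self._cached_meta = None  # CPU copy of the partition metadata tensor

    def get(self, meta: torch.Tensor) -> torch.Tensor:
        # First call: pay the device-to-host copy once and keep the result.
        if self._cached_meta is None:
            self._cached_meta = meta.to('cpu')
        # Subsequent calls reuse the cached CPU tensor instead of copying again.
        return self._cached_meta

    def reset(self):
        # Mirrors reset_activation_shape(): invalidate the cache when shapes change.
        self._cached_meta = None
```

Usage would look like `PartitionedTensor.from_meta(meta=cache.get(inputs[0]), ...)`, under the assumption that the metadata tensor does not change between calls for a fixed activation shape.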