Commit

Fix memory leak caused by context parallelism hanging references by omegaconf (NVIDIA#8299)

* save cp_size to self

Signed-off-by: Jimmy Zhang <[email protected]>

* use parallel_state instead of self

Signed-off-by: Jimmy Zhang <[email protected]>

---------

Signed-off-by: Jimmy Zhang <[email protected]>
Co-authored-by: Jimmy Zhang <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
Signed-off-by: Alexandros Koumparoulis <[email protected]>
3 people authored and akoumpa committed Feb 5, 2024
1 parent cc8d888 commit 05ad2b7
Showing 1 changed file with 3 additions and 4 deletions.
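For context, the minimal sketch below illustrates the pattern the hunks that follow apply. It assumes megatron.core is installed and that model-parallel state has already been initialized by the trainer; the import line is an assumption, since the diff only shows the call sites.

    # Before (removed): re-reading the size from the omegaconf config on every
    # call, which per the commit title left hanging references and leaked memory.
    #     cp_size = self.cfg.get('context_parallel_size', 1)

    # After (added): query Megatron's parallel_state at the point of use,
    # without touching the config object.
    from megatron.core import parallel_state  # assumed import path

    cp_size = parallel_state.get_context_parallel_world_size()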
@@ -910,11 +910,11 @@ def get_batch(self, data_iterator, tuning):
         return batch

     def get_batch_on_this_context_parallel_rank(self, batch):
-        cp_size = self.cfg.get('context_parallel_size', 1)
         num_valid_tokens_in_ub = None
         if 'loss_mask' in batch and batch['loss_mask'] is not None:
             num_valid_tokens_in_ub = batch['loss_mask'].sum()

+        cp_size = parallel_state.get_context_parallel_world_size()
         if cp_size > 1:
             cp_rank = parallel_state.get_context_parallel_rank()
             for key, val in batch.items():
@@ -1011,7 +1011,7 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
             def loss_func(output_tensor):
                 # Loss for a micro-batch (ub)
                 loss_for_ub = self.loss_func(batch['loss_mask'], batch['num_valid_tokens_in_ub'], output_tensor)
-                cp_size = self.cfg.get('context_parallel_size', 1)
+                cp_size = parallel_state.get_context_parallel_world_size()
                 if validation_step and not self.cfg.data.get('validation_drop_last', True):
                     num_valid_tokens_in_ub = batch['num_valid_tokens_in_ub']
                     if loss_for_ub.isnan():
@@ -1167,8 +1167,7 @@ def loss_func(self, loss_mask, num_valid_tokens_in_ub, output_tensor):
         loss_mask = loss_mask.view(-1).float()
         # TODO: add nemo version here
         loss = torch.sum(losses.view(-1) * loss_mask) / num_valid_tokens_in_ub  # sequence level nll
-        cp_size = self.cfg.get('context_parallel_size', 1)
-        if cp_size > 1:
+        if parallel_state.get_context_parallel_world_size() > 1:
             torch.distributed.all_reduce(loss, group=parallel_state.get_context_parallel_group())
         return loss

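The loss_func change above keeps the same reduction semantics: each context-parallel rank computes its partial token-loss sum divided by the global num_valid_tokens_in_ub (which is computed before the batch is sharded), and the SUM all-reduce over the context-parallel group recovers the full-sequence loss. A minimal single-process sketch of that arithmetic, using illustrative tensors rather than real model outputs:

    import torch

    torch.manual_seed(0)
    losses = torch.rand(8)  # per-token NLL for one sequence
    loss_mask = torch.tensor([1, 1, 1, 1, 1, 1, 0, 0], dtype=torch.float)
    num_valid_tokens_in_ub = loss_mask.sum()  # counted before CP sharding

    # Loss as computed with context_parallel_size == 1.
    full_loss = torch.sum(losses * loss_mask) / num_valid_tokens_in_ub

    # Emulate two CP ranks: each holds a disjoint slice of the sequence but
    # still divides by the global valid-token count.
    partials = [
        torch.sum(losses[:4] * loss_mask[:4]) / num_valid_tokens_in_ub,
        torch.sum(losses[4:] * loss_mask[4:]) / num_valid_tokens_in_ub,
    ]

    # The all_reduce over the context-parallel group in the diff is equivalent
    # to summing the partial losses.
    assert torch.allclose(full_loss, partials[0] + partials[1])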