diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index a37c73c015dba..e3a912a543f47 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -910,11 +910,11 @@ def get_batch(self, data_iterator, tuning):
         return batch
 
     def get_batch_on_this_context_parallel_rank(self, batch):
-        cp_size = self.cfg.get('context_parallel_size', 1)
         num_valid_tokens_in_ub = None
         if 'loss_mask' in batch and batch['loss_mask'] is not None:
             num_valid_tokens_in_ub = batch['loss_mask'].sum()
 
+        cp_size = parallel_state.get_context_parallel_world_size()
         if cp_size > 1:
             cp_rank = parallel_state.get_context_parallel_rank()
             for key, val in batch.items():
@@ -1011,7 +1011,7 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
             def loss_func(output_tensor):
                 # Loss for a micro-batch (ub)
                 loss_for_ub = self.loss_func(batch['loss_mask'], batch['num_valid_tokens_in_ub'], output_tensor)
-                cp_size = self.cfg.get('context_parallel_size', 1)
+                cp_size = parallel_state.get_context_parallel_world_size()
                 if validation_step and not self.cfg.data.get('validation_drop_last', True):
                     num_valid_tokens_in_ub = batch['num_valid_tokens_in_ub']
                     if loss_for_ub.isnan():
@@ -1167,8 +1167,7 @@ def loss_func(self, loss_mask, num_valid_tokens_in_ub, output_tensor):
         loss_mask = loss_mask.view(-1).float()
         # TODO: add nemo version here
         loss = torch.sum(losses.view(-1) * loss_mask) / num_valid_tokens_in_ub  # sequence level nll
-        cp_size = self.cfg.get('context_parallel_size', 1)
-        if cp_size > 1:
+        if parallel_state.get_context_parallel_world_size() > 1:
             torch.distributed.all_reduce(loss, group=parallel_state.get_context_parallel_group())
         return loss
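
For reference, a minimal sketch of the pattern this diff converges on: read the context-parallel (CP) world size from megatron.core's parallel_state instead of self.cfg, and sum the partial loss over the CP group. The helper name reduce_loss_across_cp_ranks is illustrative only (not part of the changed file), and the sketch assumes megatron-core is installed and parallel_state.initialize_model_parallel() has already set up the process groups, as it is inside a NeMo/Megatron training run.

# Illustrative sketch only: assumes an initialized megatron-core parallel state.
import torch
from megatron.core import parallel_state


def reduce_loss_across_cp_ranks(loss: torch.Tensor) -> torch.Tensor:
    """Sum a per-rank partial loss over the context-parallel group.

    With CP > 1 each rank only sees a slice of the sequence, so the
    masked token-loss sums must be combined across CP ranks before the
    normalized sequence-level NLL is meaningful.
    """
    if parallel_state.get_context_parallel_world_size() > 1:
        torch.distributed.all_reduce(loss, group=parallel_state.get_context_parallel_group())
    return loss

Presumably the point of the change is that querying the initialized parallel_state keeps the runtime behavior in sync with the actual process groups, even when a config omits or defaults the context_parallel_size key.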