Apply isort and black reformatting
Signed-off-by: DrownFish19 <[email protected]>
DrownFish19 committed Jan 2, 2025
1 parent c2f8205 commit a19c630
Showing 1 changed file with 6 additions and 3 deletions.
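
The commit applies the repository's import sorter (isort) and code formatter (black) to one file. As a sketch of what produced the diff below, the snippet here runs the one substantive line through black's Python API (black.format_str and isort.code are real APIs; the 119-character line length is an assumption inferred from the diff itself, since the untouched sibling assignment is 118 characters wide while the line that got wrapped was 128 — NeMo's actual settings live in its pyproject.toml):

    # A sketch, not the commit's actual tooling invocation; the real
    # formatter configuration comes from the repository's pyproject.toml.
    import black
    import isort

    SRC = (
        "effective_tokens_per_second_per_device = "
        "int(effective_tokens_per_second / float(torch.distributed.get_world_size()))\n"
    )

    tidy = isort.code(SRC)  # sorts imports; a no-op for this snippet
    wrapped = black.format_str(tidy, mode=black.Mode(line_length=119))
    print(wrapped)
    # black splits at the outermost bracket, producing the three "+" lines
    # seen in the diff below.

The same pass also accounts for the added blank line before class TokenPerformanceCallback: black enforces two blank lines before a top-level definition.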
9 changes: 6 additions & 3 deletions nemo/utils/exp_manager.py
@@ -352,6 +352,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self._on_batch_end("validation_step_timing in s", trainer, pl_module)

+
class TokenPerformanceCallback(Callback):
    """
    Logs token-level performance of train steps using the nemo logger. Calculates
@@ -394,7 +395,7 @@ def on_train_epoch_start(self, trainer, pl_module):

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self._on_batch_end("train_step_token_performance", pl_module)

        elapsed_time = self.timer["train_step_token_performance"]

        # sum local tokens
@@ -404,14 +405,16 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        torch.distributed.all_reduce(total_tokens, group=parallel_state.get_data_parallel_group())
        total_tokens_per_second = float(total_tokens) / elapsed_time
        total_tokens_per_second_per_device = int(total_tokens_per_second / float(torch.distributed.get_world_size()))

        # sum local effective tokens
        effective_tokens = sum(batch["token_count"])
        # sum effective tokens across the data_parallel_group
        effective_tokens = torch.tensor([effective_tokens]).cuda()
        torch.distributed.all_reduce(effective_tokens, group=parallel_state.get_data_parallel_group())
        effective_tokens_per_second = float(effective_tokens) / elapsed_time
-        effective_tokens_per_second_per_device = int(effective_tokens_per_second / float(torch.distributed.get_world_size()))
+        effective_tokens_per_second_per_device = int(
+            effective_tokens_per_second / float(torch.distributed.get_world_size())
+        )

        pl_module.log(
            'effective_tokens_per_second_per_device',
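
For context on what the reformatted code measures: each rank sums its local token count, all-reduces that sum across the data-parallel group, then divides by the step's wall-clock time and by the world size to get a per-device rate. Below is a minimal self-contained sketch of that pattern, assuming torch.distributed is already initialized and substituting the default process group for NeMo's parallel_state.get_data_parallel_group():

    import torch
    import torch.distributed as dist

    def tokens_per_second_per_device(local_token_count: int, elapsed_time: float) -> int:
        """Global token throughput, normalized per device.

        Mirrors the pattern in the diff above; dist.group.WORLD stands in
        for the data-parallel group used by the callback.
        """
        total = torch.tensor([local_token_count], device="cuda")
        # After all_reduce every rank holds the sum over all ranks.
        dist.all_reduce(total, group=dist.group.WORLD)
        total_per_second = float(total) / elapsed_time
        # Divide by world size so the number is comparable across job sizes.
        return int(total_per_second / float(dist.get_world_size()))

The callback computes this twice: once over all tokens and once over "effective" tokens taken from batch["token_count"], which presumably excludes padding; comparing the two rates indicates how much of the measured throughput is spent on padding tokens.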
