From a19c63024de1755eae10473137d7bda2c204d754 Mon Sep 17 00:00:00 2001 From: DrownFish19 Date: Thu, 2 Jan 2025 07:22:15 +0000 Subject: [PATCH] Apply isort and black reformatting Signed-off-by: DrownFish19 --- nemo/utils/exp_manager.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index f9154ebb9f54..27580d14f2e5 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -352,6 +352,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): self._on_batch_end("validation_step_timing in s", trainer, pl_module) + class TokenPerformanceCallback(Callback): """ Logs performance in token-level of train steps using nemo logger. Calculates @@ -394,7 +395,7 @@ def on_train_epoch_start(self, trainer, pl_module): def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): self._on_batch_end("train_step_token_performance", pl_module) - + elapsed_time = self.timer["train_step_token_performance"] # sum local tokens @@ -404,14 +405,16 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): torch.distributed.all_reduce(total_tokens, group=parallel_state.get_data_parallel_group()) total_tokens_per_second = float(total_tokens) / elapsed_time total_tokens_per_second_per_device = int(total_tokens_per_second / float(torch.distributed.get_world_size())) - + # sum local tokens effective_tokens = sum(batch["token_count"]) # sum tokens cross all data_parallel_group effective_tokens = torch.tensor([effective_tokens]).cuda() torch.distributed.all_reduce(effective_tokens, group=parallel_state.get_data_parallel_group()) effective_tokens_per_second = float(effective_tokens) / elapsed_time - effective_tokens_per_second_per_device = int(effective_tokens_per_second / float(torch.distributed.get_world_size())) + effective_tokens_per_second_per_device = int( + effective_tokens_per_second / float(torch.distributed.get_world_size()) + ) pl_module.log( 'effective_tokens_per_second_per_device',