diff --git a/nemo/lightning/pytorch/callbacks/nsys.py b/nemo/lightning/pytorch/callbacks/nsys.py
index 2a6bc3668b94..13b059011426 100644
--- a/nemo/lightning/pytorch/callbacks/nsys.py
+++ b/nemo/lightning/pytorch/callbacks/nsys.py
@@ -21,6 +21,18 @@
 from nemo.utils.get_rank import get_rank
 
 
+def get_current_epoch_step(trainer) -> int:
+    """
+    Get the value of step within an epoch.
+    """
+    if hasattr(trainer.strategy, 'current_epoch_step'):
+        return trainer.strategy.current_epoch_step
+    return max(
+        trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.current.completed,
+        trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.current.completed,
+    )
+
+
 class NsysCallback(Callback):
     """
     A PyTorch Lightning callback for NVIDIA Nsight Systems (Nsys) profiling.
@@ -67,39 +79,41 @@ def __init__(
             f'and end_step: {self._nsys_profile_end_step}'
         )
 
+    def _rank_is_active(self, trainer):
+        # TODO(@akoumparouli): is this function cache-able?
+        from lightning.pytorch.strategies import SingleDeviceStrategy
+
+        if isinstance(trainer.strategy, SingleDeviceStrategy):
+            return True
+        if not torch.distributed.is_initialized():
+            return True
+        return get_rank() in self._nsys_profile_ranks
+
     def on_train_batch_start(self, trainer, pl_module, batch, batch_idx: int) -> Optional[int]:
         """PyTorch Lightning hook:
         https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-batch-start
         We use it here to enable nsys profiling.
         """
+        if not self._rank_is_active(trainer) or trainer.strategy.root_device.type != 'cuda':
+            return
 
-        device = trainer.strategy.root_device
-        try:
-            # Not all strategies have this. e.g.:
-            # AttributeError: 'SingleDeviceStrategy' object has no attribute 'current_epoch_step'
-            current_step = trainer.strategy.current_epoch_step
-        except AttributeError:
-            current_step = self._nsys_profile_start_step
-        if device.type == 'cuda':
-            if current_step == self._nsys_profile_start_step and get_rank() in self._nsys_profile_ranks:
-                torch.cuda.cudart().cudaProfilerStart()
-                if self._nsys_profile_gen_shape:
-                    torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()
-                else:
-                    torch.autograd.profiler.emit_nvtx().__enter__()
+        current_step = get_current_epoch_step(trainer)
+        if current_step == self._nsys_profile_start_step:
+            torch.cuda.cudart().cudaProfilerStart()
+            if self._nsys_profile_gen_shape:
+                torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()
+            else:
+                torch.autograd.profiler.emit_nvtx().__enter__()
 
     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int) -> None:
         """PyTorch Lightning hook:
         https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-batch-end
         We use it here to enable nsys profiling.
         """
+        if not self._rank_is_active(trainer) or trainer.strategy.root_device.type != 'cuda':
+            return
 
-        device = trainer.strategy.root_device
-        try:
-            current_step = trainer.strategy.current_epoch_step
-        except AttributeError:
-            current_step = self._nsys_profile_end_step
-        if device.type == 'cuda':
-            if current_step == self._nsys_profile_end_step and get_rank() in self._nsys_profile_ranks:
-                torch.cuda.cudart().cudaProfilerStop()
-                torch.autograd.profiler.emit_nvtx().__exit__(None, None, None)
+        current_step = get_current_epoch_step(trainer)
+        if current_step == self._nsys_profile_end_step:
+            torch.cuda.cudart().cudaProfilerStop()
+            torch.autograd.profiler.emit_nvtx().__exit__(None, None, None)
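
For reviewers, a minimal usage sketch. The constructor argument names (`start_step`, `end_step`, `ranks`, `gen_shape`) are inferred from the `_nsys_profile_*` attributes referenced above; the full `__init__` signature is outside this hunk, so treat them as assumptions.

```python
from lightning.pytorch import Trainer

from nemo.lightning.pytorch.callbacks.nsys import NsysCallback

# Profile ranks 0 and 1 between epoch steps 10 and 20, recording
# tensor shapes in the emitted NVTX ranges.
callback = NsysCallback(start_step=10, end_step=20, ranks=[0, 1], gen_shape=True)
trainer = Trainer(callbacks=[callback], max_steps=100)
# trainer.fit(model, datamodule=...)

# cudaProfilerStart/Stop only gate capture when the process runs under
# Nsight Systems with capture-range control, e.g.:
#   nsys profile -s none -t cuda,nvtx \
#       --capture-range=cudaProfilerApi --capture-range-end=stop \
#       -o nemo_profile python train.py
```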
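The two hooks reduce to a start/stop pattern around a window of training steps; here is that pattern shown standalone (step bounds and the loop body are illustrative, not part of the callback):

```python
import torch

PROFILE_START, PROFILE_END = 10, 20
nvtx_ctx = None  # emit_nvtx context manager, entered manually across steps

for step in range(100):
    # At the first profiled step: start the CUDA profiler capture range
    # and begin emitting NVTX markers for every autograd op.
    if step == PROFILE_START and torch.cuda.is_available():
        torch.cuda.cudart().cudaProfilerStart()
        nvtx_ctx = torch.autograd.profiler.emit_nvtx(record_shapes=True)
        nvtx_ctx.__enter__()

    # ... one training step would run here ...

    # After the last profiled step: close the capture range and the
    # NVTX context so later steps run unprofiled.
    if step == PROFILE_END and nvtx_ctx is not None:
        torch.cuda.cudart().cudaProfilerStop()
        nvtx_ctx.__exit__(None, None, None)
        nvtx_ctx = None
```

The diff's `get_current_epoch_step` fallback takes the max of the automatic and manual optimization progress counters, so the step comparison works under either optimization mode when the strategy does not expose `current_epoch_step`.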