Fix SingleDeviceStrategy support in Nsys callback (#11574)
* fix for SingleDeviceStrategy

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* mini refactor

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* typo

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* fix

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Apply isort and black reformatting

Signed-off-by: akoumpa <[email protected]>

---------

Signed-off-by: Alexandros Koumparoulis <[email protected]>
Signed-off-by: akoumpa <[email protected]>
Co-authored-by: akoumpa <[email protected]>
Authored by akoumpa on Dec 13, 2024
1 parent aa6eba2 · commit 171a7af
Showing 1 changed file: nemo/lightning/pytorch/callbacks/nsys.py (37 additions, 23 deletions)
```diff
@@ -21,6 +21,18 @@
 from nemo.utils.get_rank import get_rank
 
 
+def get_current_epoch_step(trainer) -> int:
+    """
+    Get the value of step within an epoch.
+    """
+    if hasattr(trainer.strategy, 'current_epoch_step'):
+        return trainer.strategy.current_epoch_step
+    return max(
+        trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.current.completed,
+        trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.current.completed,
+    )
+
+
 class NsysCallback(Callback):
     """
     A PyTorch Lightning callback for NVIDIA Nsight Systems (Nsys) profiling.
```
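The helper above replaces a per-hook `try`/`except AttributeError` around `trainer.strategy.current_epoch_step` (rewired in the second hunk below). The old fallback assigned `current_step = self._nsys_profile_start_step` (or the end step), which made the trigger comparison trivially true, so under `SingleDeviceStrategy` the profiler start/stop logic fired on every batch. The new helper prefers the strategy's own counter when it exists and otherwise falls back to Lightning's fit-loop progress counters, taking the max of the automatic- and manual-optimization step counts. A minimal sketch of that behavior, using hypothetical `SimpleNamespace` stand-ins shaped like the attribute paths the helper reads (the mock trainer is illustrative, not a real `Trainer`):

```python
from types import SimpleNamespace

def _progress(completed):
    # Mimics a lightning progress object: .current.completed
    return SimpleNamespace(current=SimpleNamespace(completed=completed))

def make_mock_trainer(auto_steps, manual_steps, epoch_step=None):
    # Hypothetical stand-in for a Trainer, shaped like the attribute
    # paths read by get_current_epoch_step above.
    strategy = SimpleNamespace()
    if epoch_step is not None:
        strategy.current_epoch_step = epoch_step  # strategies that expose their own counter
    epoch_loop = SimpleNamespace(
        automatic_optimization=SimpleNamespace(
            optim_progress=SimpleNamespace(
                optimizer=SimpleNamespace(step=_progress(auto_steps))
            )
        ),
        manual_optimization=SimpleNamespace(optim_step_progress=_progress(manual_steps)),
    )
    return SimpleNamespace(strategy=strategy, fit_loop=SimpleNamespace(epoch_loop=epoch_loop))

# With get_current_epoch_step as defined in the hunk above:
assert get_current_epoch_step(make_mock_trainer(5, 0)) == 5      # automatic optimization
assert get_current_epoch_step(make_mock_trainer(0, 3)) == 3      # manual optimization
assert get_current_epoch_step(make_mock_trainer(1, 1, 8)) == 8   # strategy provides the counter
```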
```diff
@@ -67,39 +79,41 @@ def __init__(
             f'and end_step: {self._nsys_profile_end_step}'
         )
 
+    def _rank_is_active(self, trainer):
+        # TODO(@akoumparouli): is this function cache-able?
+        from lightning.pytorch.strategies import SingleDeviceStrategy
+
+        if isinstance(trainer.strategy, SingleDeviceStrategy):
+            return True
+        if not torch.distributed.is_initialized():
+            return True
+        return get_rank() in self._nsys_profile_ranks
+
     def on_train_batch_start(self, trainer, pl_module, batch, batch_idx: int) -> Optional[int]:
         """PyTorch Lightning hook:
         https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-batch-start
         We use it here to enable nsys profiling.
         """
+        if not self._rank_is_active(trainer) or trainer.strategy.root_device.type != 'cuda':
+            return
 
-        device = trainer.strategy.root_device
-        try:
-            # Not all strategies have this. e.g.:
-            # AttributeError: 'SingleDeviceStrategy' object has no attribute 'current_epoch_step'
-            current_step = trainer.strategy.current_epoch_step
-        except AttributeError:
-            current_step = self._nsys_profile_start_step
-        if device.type == 'cuda':
-            if current_step == self._nsys_profile_start_step and get_rank() in self._nsys_profile_ranks:
-                torch.cuda.cudart().cudaProfilerStart()
-                if self._nsys_profile_gen_shape:
-                    torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()
-                else:
-                    torch.autograd.profiler.emit_nvtx().__enter__()
+        current_step = get_current_epoch_step(trainer)
+        if current_step == self._nsys_profile_start_step:
+            torch.cuda.cudart().cudaProfilerStart()
+            if self._nsys_profile_gen_shape:
+                torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()
+            else:
+                torch.autograd.profiler.emit_nvtx().__enter__()
 
     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int) -> None:
         """PyTorch Lightning hook:
         https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-batch-end
         We use it here to enable nsys profiling.
         """
+        if not self._rank_is_active(trainer) or trainer.strategy.root_device.type != 'cuda':
+            return
 
-        device = trainer.strategy.root_device
-        try:
-            current_step = trainer.strategy.current_epoch_step
-        except AttributeError:
-            current_step = self._nsys_profile_end_step
-        if device.type == 'cuda':
-            if current_step == self._nsys_profile_end_step and get_rank() in self._nsys_profile_ranks:
-                torch.cuda.cudart().cudaProfilerStop()
-                torch.autograd.profiler.emit_nvtx().__exit__(None, None, None)
+        current_step = get_current_epoch_step(trainer)
+        if current_step == self._nsys_profile_end_step:
+            torch.cuda.cudart().cudaProfilerStop()
+            torch.autograd.profiler.emit_nvtx().__exit__(None, None, None)
```
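With `_rank_is_active`, the rank gate is computed in one place: a `SingleDeviceStrategy` run is always considered active, a process without `torch.distributed` initialized is likewise active, and otherwise the rank must appear in the configured ranks list. Both hooks now also return early on non-CUDA devices instead of re-checking `device.type` inline. A minimal usage sketch follows; it is not part of the commit. The constructor arguments (`start_step`, `end_step`, `ranks`, `gen_shape`) are inferred from the attribute names in this file, so check your NeMo version for the exact signature; the `nsys` flags are the usual capture-range pairing for `cudaProfilerStart`/`cudaProfilerStop`.

```python
# Launch under Nsight Systems so the callback's cudaProfilerStart/Stop calls
# delimit the capture window, e.g.:
#   nsys profile -s none -t nvtx,cuda --capture-range=cudaProfilerApi \
#       --capture-range-end=stop python train.py
import lightning.pytorch as pl
from nemo.lightning.pytorch.callbacks.nsys import NsysCallback

nsys_cb = NsysCallback(
    start_step=10,    # cudaProfilerStart fires when this step is reached
    end_step=20,      # cudaProfilerStop fires at this step
    ranks=[0],        # ignored under SingleDeviceStrategy, which is always active
    gen_shape=True,   # emit_nvtx(record_shapes=True): record tensor shapes in NVTX ranges
)

# A single device yields SingleDeviceStrategy, the case this commit fixes.
trainer = pl.Trainer(accelerator="gpu", devices=1, max_steps=30, callbacks=[nsys_cb])
# trainer.fit(model, datamodule)  # model and datamodule are user-supplied
```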
