Skip to content

Commit

Permalink
Disable nvtx decorator to avoid graph break (#5697)
Browse files Browse the repository at this point in the history
`instrument_w_nvtx` breaks a graph as `range_push` and `range_pop`
return a non-tensor int.
This PR disables the decorator to avoid the break graph.

This actually impacts the performance. In my environment, the training
iteration time using Llama-3-8B/4GPUs/ZeRO1 is improved from 3.02s ->
2.54s.

---------

Co-authored-by: Logan Adams <[email protected]>
  • Loading branch information
tohtana and loadams authored Jun 26, 2024
1 parent e9ffe02 commit b421e8c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
3 changes: 3 additions & 0 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3613,6 +3613,9 @@ def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwarg
"""Compile the module using the specified backend and kwargs.
If a compiler_fn is set, it will be used instead of torch.compile().
"""
# Avoid graph breaks
deepspeed.utils.nvtx.enable_nvtx = False

if not is_compile_supported():
raise RuntimeError("compile is not supported in your version of PyTorch.")

Expand Down
8 changes: 6 additions & 2 deletions deepspeed/utils/nvtx.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,19 @@

from deepspeed.accelerator import get_accelerator

enable_nvtx = True


def instrument_w_nvtx(func):
"""decorator that causes an NVTX range to be recorded for the duration of the
function call."""

def wrapped_fn(*args, **kwargs):
get_accelerator().range_push(func.__qualname__)
if enable_nvtx:
get_accelerator().range_push(func.__qualname__)
ret_val = func(*args, **kwargs)
get_accelerator().range_pop()
if enable_nvtx:
get_accelerator().range_pop()
return ret_val

return wrapped_fn

0 comments on commit b421e8c

Please sign in to comment.