Skip to content

Commit

Permalink
Merge branch 'master' into mrwyattii/pydantic-2-support
Browse files Browse the repository at this point in the history
  • Loading branch information
adk9 authored Jun 27, 2024
2 parents f973393 + b421e8c commit 09fa6b5
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
3 changes: 3 additions & 0 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3613,6 +3613,9 @@ def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwarg
"""Compile the module using the specified backend and kwargs.
If a compiler_fn is set, it will be used instead of torch.compile().
"""
# Avoid graph breaks
deepspeed.utils.nvtx.enable_nvtx = False

if not is_compile_supported():
raise RuntimeError("compile is not supported in your version of PyTorch.")

Expand Down
8 changes: 6 additions & 2 deletions deepspeed/utils/nvtx.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,19 @@

from deepspeed.accelerator import get_accelerator

enable_nvtx = True


def instrument_w_nvtx(func):
"""decorator that causes an NVTX range to be recorded for the duration of the
function call."""

def wrapped_fn(*args, **kwargs):
get_accelerator().range_push(func.__qualname__)
if enable_nvtx:
get_accelerator().range_push(func.__qualname__)
ret_val = func(*args, **kwargs)
get_accelerator().range_pop()
if enable_nvtx:
get_accelerator().range_pop()
return ret_val

return wrapped_fn

0 comments on commit 09fa6b5

Please sign in to comment.