diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 366db3a972a3..d40141132aaf 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -3613,6 +3613,9 @@ def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwarg
         """Compile the module using the specified backend and kwargs.
         If a compiler_fn is set, it will be used instead of torch.compile().
         """
+        # Avoid graph breaks
+        deepspeed.utils.nvtx.enable_nvtx = False
+
         if not is_compile_supported():
             raise RuntimeError("compile is not supported in your version of PyTorch.")
 
diff --git a/deepspeed/utils/nvtx.py b/deepspeed/utils/nvtx.py
index 3823599e7bf2..7c566480a86a 100644
--- a/deepspeed/utils/nvtx.py
+++ b/deepspeed/utils/nvtx.py
@@ -5,15 +5,27 @@
 
 from deepspeed.accelerator import get_accelerator
 
+# Module-level switch consulted by instrument_w_nvtx on every call; engine.compile()
+# sets it to False so the NVTX calls do not introduce graph breaks.
+enable_nvtx = True
+
 
 def instrument_w_nvtx(func):
     """decorator that causes an NVTX range to be recorded for the duration of the
     function call."""
 
     def wrapped_fn(*args, **kwargs):
-        get_accelerator().range_push(func.__qualname__)
-        ret_val = func(*args, **kwargs)
-        get_accelerator().range_pop()
-        return ret_val
+        # Read the switch once per call. engine.compile() can turn it off while
+        # func is still running; re-reading it after the call could skip the pop
+        # for a range that was already pushed.
+        if not enable_nvtx:
+            return func(*args, **kwargs)
+
+        get_accelerator().range_push(func.__qualname__)
+        try:
+            return func(*args, **kwargs)
+        finally:
+            # Close the range even if func raises.
+            get_accelerator().range_pop()
 
     return wrapped_fn