From b421e8c8f31af254b63ad6e9839f617ab6d9c060 Mon Sep 17 00:00:00 2001
From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Date: Wed, 26 Jun 2024 15:30:56 -0700
Subject: [PATCH] Disable nvtx decorator to avoid graph break (#5697)

`instrument_w_nvtx` breaks the `torch.compile` graph because `range_push`
and `range_pop` return a non-tensor int. This PR disables the decorator
when the engine is compiled to avoid the graph break.

The break has a real performance cost: in my environment, the training
iteration time using Llama-3-8B/4GPUs/ZeRO1 improved from 3.02s to 2.54s.

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
---
 deepspeed/runtime/engine.py | 3 +++
 deepspeed/utils/nvtx.py     | 8 ++++++--
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 366db3a972a3..d40141132aaf 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -3613,6 +3613,9 @@ def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwarg
         """Compile the module using the specified backend and kwargs.
         If a compiler_fn is set, it will be used instead of torch.compile().
         """
+        # Avoid graph breaks
+        deepspeed.utils.nvtx.enable_nvtx = False
+
         if not is_compile_supported():
             raise RuntimeError("compile is not supported in your version of PyTorch.")
 
diff --git a/deepspeed/utils/nvtx.py b/deepspeed/utils/nvtx.py
index 3823599e7bf2..7c566480a86a 100644
--- a/deepspeed/utils/nvtx.py
+++ b/deepspeed/utils/nvtx.py
@@ -5,15 +5,19 @@
 
 from deepspeed.accelerator import get_accelerator
 
+enable_nvtx = True
+
 
 def instrument_w_nvtx(func):
     """decorator that causes an NVTX range to be recorded for the duration of the function call."""
 
     def wrapped_fn(*args, **kwargs):
-        get_accelerator().range_push(func.__qualname__)
+        if enable_nvtx:
+            get_accelerator().range_push(func.__qualname__)
         ret_val = func(*args, **kwargs)
-        get_accelerator().range_pop()
+        if enable_nvtx:
+            get_accelerator().range_pop()
         return ret_val
 
     return wrapped_fn
 
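
For context (not part of the patch), a minimal sketch of the graph break this
change avoids. On the CUDA accelerator, range_push/range_pop map to
torch.cuda.nvtx calls, which Dynamo cannot trace. This sketch assumes
PyTorch >= 2.0 and a CUDA device; `traced_op` is a hypothetical stand-in for
a function wrapped by `instrument_w_nvtx`, and `fullgraph=True` makes Dynamo
raise at the break instead of silently splitting the graph.

import torch

def traced_op(x):
    # range_push returns a non-tensor int that Dynamo cannot trace,
    # so compilation splits the graph at this call site.
    torch.cuda.nvtx.range_push("traced_op")
    y = x * 2
    torch.cuda.nvtx.range_pop()
    return y

compiled = torch.compile(traced_op, fullgraph=True)
try:
    compiled(torch.randn(8, device="cuda"))
except Exception as err:
    print(f"graph break: {err}")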
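
And a sketch of how the new flag behaves at runtime (module and attribute
names are taken from the diff; `step` is a hypothetical decorated function,
and DeepSpeed is assumed to be installed). Because `wrapped_fn` reads
`enable_nvtx` from its defining module's globals at call time, flipping the
flag takes effect immediately for every already-decorated function, which is
why `engine.compile()` can disable instrumentation without re-wrapping
anything:

import deepspeed.utils.nvtx as nvtx
from deepspeed.utils.nvtx import instrument_w_nvtx

@instrument_w_nvtx
def step():
    # hypothetical stand-in for an engine-internal function
    return 42

step()                    # enable_nvtx defaults to True: range pushed/popped
nvtx.enable_nvtx = False  # what engine.compile() now does before compiling
step()                    # no range_push/range_pop, hence no graph break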