From 3c0043a3009f69c6eb536642a792a12aca4c759c Mon Sep 17 00:00:00 2001
From: zhijxu
Date: Wed, 21 Feb 2024 15:39:16 +0800
Subject: [PATCH] use torch stream

---
 .../python/training/ortmodule/_inference_manager.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py
index 6690af9b71bf1..7e063ab1fbdc0 100644
--- a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py
+++ b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py
@@ -222,8 +222,17 @@ def _build_graph(self, graph_transformer_config):
 
     @TrackTime(ORTModuleInitPhase.CREATE_SESSION)
     def _create_execution_agent(self):
         """Creates an InferenceAgent that can run forward graph on an inference model"""
+        import os
+        self._runtime_options.use_torch_stream = os.environ.get("ORTMODULE_USE_TORCH_STREAM", "0") == "1"
+        if self._runtime_options.use_torch_stream:
+            self._runtime_options.use_external_gpu_allocator = False
         session_options, providers, provider_options = self._get_session_config()
+        if self._runtime_options.use_torch_stream:
+            # torch_stream = str(torch.cuda.Stream().cuda_stream)
+            torch_stream = os.environ["ORTMODULE_TORCH_STREAM"]
+            print(f"--------ortmodule will use torch stream: {torch_stream}--------")
+            provider_options[0]["user_compute_stream"] = str(torch_stream)
         self._execution_agent = InferenceAgent(
             self._onnx_models.optimized_model.SerializeToString(), session_options, providers, provider_options
         )
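
For context, a minimal usage sketch (not part of the patch): it shows how the two environment variables introduced above could be driven from the caller's side. The Linear model is a placeholder, and the sketch assumes a CUDA build of onnxruntime-training with this patch applied; `stream.cuda_stream` is the real PyTorch accessor for the raw cudaStream_t handle, and `user_compute_stream` is the existing CUDA execution provider option the patch forwards it to.

    # Hypothetical usage of the patched ORTModule with a caller-owned CUDA stream.
    import os

    import torch
    from onnxruntime.training.ortmodule import ORTModule

    # Create the stream first so its raw cudaStream_t handle can be exported
    # through the environment variables that _create_execution_agent reads.
    stream = torch.cuda.Stream()
    os.environ["ORTMODULE_USE_TORCH_STREAM"] = "1"
    os.environ["ORTMODULE_TORCH_STREAM"] = str(stream.cuda_stream)

    # eval() keeps execution on the InferenceManager path the patch modifies.
    model = ORTModule(torch.nn.Linear(8, 8).cuda()).eval()

    # Run the PyTorch work and the ORT forward on the same stream, so no
    # cross-stream synchronization is needed between them.
    with torch.cuda.stream(stream), torch.no_grad():
        x = torch.randn(4, 8, device="cuda")
        y = model(x)
    torch.cuda.current_stream().wait_stream(stream)
    print(y.shape)

Note that the patch also sets use_external_gpu_allocator to False whenever a user stream is supplied, so in that mode ORT allocates from its own CUDA allocator rather than PyTorch's.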