Skip to content

Commit

Permalink
use torch stream
Browse files Browse the repository at this point in the history
  • Loading branch information
zhijxu-MS committed Feb 21, 2024
1 parent 8092a89 commit 3c0043a
Showing 1 changed file with 9 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,17 @@ def _build_graph(self, graph_transformer_config):
@TrackTime(ORTModuleInitPhase.CREATE_SESSION)
def _create_execution_agent(self):
"""Creates an InferenceAgent that can run forward graph on an inference model"""
import os

self._runtime_options.use_torch_stream = os.environ.get("ORTMODULE_USE_TORCH_STREAM", "0") == "1"
if self._runtime_options.use_torch_stream:
self._runtime_options.use_external_gpu_allocator = False
session_options, providers, provider_options = self._get_session_config()
if self._runtime_options.use_torch_stream:
# torch_stream = str(torch.cuda.Stream().cuda_stream)
torch_stream = os.environ["ORTMODULE_TORCH_STREAM"]
print(f"--------ortmodule will use torch stream: {torch_stream}--------")
provider_options[0]["user_compute_stream"] = str(torch_stream)
self._execution_agent = InferenceAgent(
self._onnx_models.optimized_model.SerializeToString(), session_options, providers, provider_options
)
Expand Down

0 comments on commit 3c0043a

Please sign in to comment.