support user_compute_stream for rocm ep #19619
Conversation
Please resolve the build error in the pipelines.
Please add documentation for the ROCm provider options after this pull request.
The modification to the CUDA EP made the user_compute_stream test case fail, so I reverted the CUDA EP change and made the ROCm EP behave the same as the CUDA EP.
onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc
Please add a test case in test/python/onnxruntime_test_python.py
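A minimal sketch of what such a test case could look like, assuming a ROCm build of onnxruntime with torch available; the test name and `MODEL_PATH` are hypothetical placeholders, not the names used in the actual test file:

```python
import onnxruntime as onnxrt
import torch


def test_register_custom_user_compute_stream_rocm():
    # Create a real stream via torch (ROCm builds of torch expose HIP
    # streams through the torch.cuda namespace) and pass its raw handle
    # to the ROCm EP as a string.
    s = torch.cuda.Stream()
    sess = onnxrt.InferenceSession(
        MODEL_PATH,  # hypothetical: path to any small test model
        providers=[("ROCMExecutionProvider",
                    {"user_compute_stream": str(s.cuda_stream)})],
    )
    options = sess.get_provider_options()["ROCMExecutionProvider"]
    # Setting user_compute_stream should implicitly flip
    # has_user_compute_stream to "1".
    assert options.get("user_compute_stream", "") == str(s.cuda_stream)
    assert options.get("has_user_compute_stream", "") == "1"
```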
### Description

* Implement `user_compute_stream` python api for TensorRT EP
* Using this option implicitly sets `has_user_compute_stream` to `true`
* Extend the existing TRT EP unit test to verify the `user_compute_stream` option
* This has been verified in a local pytorch env, with `torch.cuda.Stream()` passed into `user_compute_stream`:

```python
...
# Before inference
if torch.cuda.is_available():
    s = torch.cuda.Stream()
    option = {"user_compute_stream": str(s.cuda_stream)}
    sess.set_providers(["TensorrtExecutionProvider"], [option])
    options = sess.get_provider_options()
    assert "TensorrtExecutionProvider" in options
    assert options["TensorrtExecutionProvider"].get("user_compute_stream", "") == str(s.cuda_stream)
    assert options["TensorrtExecutionProvider"].get("has_user_compute_stream", "") == "1"
...
```

### Motivation and Context

Align with the existing `user_compute_stream` python implementations for the [CUDA EP](https://github.com/microsoft/onnxruntime/pull/19229) and [ROCm EP](#19619).
Description
Following PR #19229, which added support for the CUDA EP to use an external compute stream, we add the same support for the ROCm EP.

While testing this feature with torch, we found that torch uses stream 0 as its default stream: `torch.cuda.current_stream()` returns `0` for the current stream, but ORT treats `0` or `nullptr` as invalid and resets `has_user_compute_stream` to false. This made it confusing to tell whether the option had taken effect, so we add a warning log for an invalid compute_stream. The `has_user_compute_stream` option will be removed in the future.
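A minimal usage sketch of the option described above, assuming a ROCm build of onnxruntime and torch, and a placeholder `model.onnx`; note the explicitly created stream, since the default stream handle `0` is rejected as invalid:

```python
import onnxruntime as onnxrt
import torch

# Explicitly create a stream: torch's default stream has handle 0,
# which ORT treats as invalid (it would reset has_user_compute_stream
# to false and log a warning).
s = torch.cuda.Stream()

sess = onnxrt.InferenceSession("model.onnx", providers=["ROCMExecutionProvider"])
sess.set_providers(
    ["ROCMExecutionProvider"],
    [{"user_compute_stream": str(s.cuda_stream)}],
)
```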
Motivation and Context
The motivation for this PR is that we want to use torch.cuda.graph to capture the kernels ORT launches, which requires torch and ORT to run on the same stream, so we use this API to set ORT's working stream.
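For context, a rough sketch of the capture pattern this enables, under stated assumptions: `sess` was created with `user_compute_stream` pointing at the same stream `s` as above, and `io_binding` is a pre-built hypothetical `IOBinding` with device-resident inputs and outputs:

```python
import torch

g = torch.cuda.CUDAGraph()
with torch.cuda.stream(s):
    # Warm-up run outside capture, as graph capture generally requires.
    sess.run_with_iobinding(io_binding)
    with torch.cuda.graph(g, stream=s):
        # Because ORT's compute stream is s, its kernels are launched on
        # the stream being captured and get recorded into the graph.
        sess.run_with_iobinding(io_binding)

g.replay()  # re-executes the captured ORT kernels
```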