[Performance] How can I forcefully assign nodes to CUDA EP? #17930
Comments
You can try ORT 1.16.1. Running some nodes on the CPU does not always prevent CUDA Graph in ORT 1.16.1. However, if a tensor needs to be copied from host to device, or device to host, that will prevent CUDA Graph. If that does not work, try session.disable_fallback() to see whether it helps. If the above does not work, you will need to use an offline script (like this) to remove the shape computation nodes and replace them with initializers.
At the cost of performance? While the CUDA EP has a Shape operator, it uses the CPU EP's implementation, as the shape information is in CPU-allocated memory, not CUDA memory.
The output of the CUDA Shape node is CPU based memory.
What generally happens is that after a Shape node, some manipulations are applied to this CPU-based data (Gather, Slice, Unsqueeze, Concat type things), and it's less efficient to attempt those on CUDA. onnxruntime/onnxruntime/core/framework/fallback_cpu_capability.cc Lines 97 to 101 in dad70ad
If you set log level to INFO you should see messages from here: onnxruntime/onnxruntime/core/framework/fallback_cpu_capability.cc Lines 161 to 163 in dad70ad
Yes, just to compare performance as-is vs. using CUDA Graphs. However, if even the CUDA Shape node outputs to CPU memory, it sounds like it still won't work with CUDA Graphs. I'll try @tianleiwu's suggestions and report back.
In ORT 1.16.1, when the shape computation (on CPU) ends with a Reshape node (that's the common use case), it can still work with CUDA graph.
The model runs fine with CUDA Graphs after upgrading to ORT 1.16.1. Thanks!
Describe the issue
Similar to #16863, I have a model in which some nodes are assigned to the CPU EP with the following warning:
The nodes placed on the CPU EP are Shape and nodes that consume the Shape output (such as Gather, Unsqueeze, Concat in my model). I want to run the model using CUDA Graphs, so I need all nodes to be placed on the CUDA EP. Is there a way I can force assignment of these nodes onto the CUDA EP?
To reproduce
Urgency
No response
Platform
Linux
OS Version
Ubuntu 20.04
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.13.1
ONNX Runtime API
Python
Architecture
X64
Execution Provider
CUDA
Execution Provider Library Version
No response
Model File
No response
Is this a quantized model?
No