
[Performance] CUDA kernel not found in registries for Op type: ScatterND #21148

Open
11721206 opened this issue Jun 23, 2024 · 8 comments

Labels
ep:CUDA (issues related to the CUDA execution provider) · performance (issues related to performance regressions) · stale (issues that have not been addressed in a while; categorized by a bot)

Comments

@11721206

11721206 commented Jun 23, 2024

Describe the issue

Hello, I want to know why, after downloading the onnxruntime v1.17.3 library, I still get "CUDA kernel not found in registries for Op type: ScatterND" and inference performance is very slow. Are there any steps I should take?

I think there are existing issues and PRs that have fixed this problem, but I still run into it.

To reproduce

1. I export the GPT-SoVITS model with opset version 17 (the highest version my torch.onnx.export supports).
2. I use the C++ onnxruntime 1.17.3 API to run inference with the CUDAExecutionProvider (a minimal sketch of these two steps follows below).
3. In the VITS module, CUDA inference is faster than CPU.
4. But in the GPT module, CUDA is slower than CPU; in the logs I find "CUDA kernel not found in registries for Op type: ScatterND" as well as some MemcpyToHost and MemcpyFromHost nodes.
5. torch version 2.2, onnx version 1.16.
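
Roughly what steps 1 and 2 look like (the module here is only a placeholder, the real export code is GPT_SoVITS/onnx_export.py; step 2 is done through the C++ API in my setup but is shown in Python for brevity):

import torch
import onnxruntime as ort

model = torch.nn.Linear(8, 8).eval()   # placeholder module, not the real GPT-SoVITS decoder
dummy_input = torch.randn(1, 8)

# Step 1: export to ONNX with opset 17.
torch.onnx.export(model, dummy_input, "model.onnx", opset_version=17)

# Step 2: run with the CUDA execution provider and check which providers are active.
sess = ort.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])
print(sess.get_providers())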

Urgency

Download from release

Platform

Linux

OS Version

Ubuntu 18.04

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

onnxruntime-gpu 1.17.3

ONNX Runtime API

C++

Architecture

X64

Execution Provider

CUDA

Execution Provider Library Version

No response

Model File

No response

Is this a quantized model?

No

github-actions bot added the ep:CUDA (issues related to the CUDA execution provider) label Jun 23, 2024
11721206 changed the title from "[Performance]" to "[Performance] CUDA kernel not found in registries for Op type: ScatterND" Jun 23, 2024
@mindest
Contributor

mindest commented Jun 25, 2024

Could you also share the code you used in the To reproduce section? CUDA kernel for ScatterND shouldn't be missing.

@11721206
Author

11721206 commented Jun 25, 2024

Could you also share the code you used in the To reproduce section? CUDA kernel for ScatterND shouldn't be missing.

I just used this code (https://github.com/RVC-Boss/GPT-SoVITS/blob/main/GPT_SoVITS/onnx_export.py) to convert the torch model into ONNX,
and then used

import onnxruntime
sessopt = onnxruntime.SessionOptions()
sessopt.log_severity_level = 1
sess = onnxruntime.InferenceSession("onnx/onnx_cc/onnx_t2s_cc_fsdec.onnx", sess_options=sessopt,  providers=["CUDAExecutionProvider"])

In the Python InferenceSession, I print the sess.get_providers() result, and it shows ["CUDAExecutionProvider", "CPUExecutionProvider"], which means some ops run on the CPU.
In my log I see many "CUDA kernel not found in registries" messages and MemcpyFromHost/MemcpyToHost nodes, as below:

[screenshot of the session log showing the missing-kernel and Memcpy messages]
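
As an extra check, the op types and declared opset of the exported graph can be listed directly (just an illustrative snippet):

import onnx
from collections import Counter

m = onnx.load("onnx/onnx_cc/onnx_t2s_cc_fsdec.onnx")
print(Counter(node.op_type for node in m.graph.node))  # ScatterND appears among the op types
print(m.opset_import)                                  # declared opset imports (exported with opset 17)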

I use these ONNX models in C++ and the same problem still occurs, and the inference speed is slow.

One thing I want to mention: vits.onnx, also exported by onnx_export.py, does not have this problem.
Waiting for a reply and advice on solutions. Thanks so much.

sophies927 added the performance (issues related to performance regressions) label Jun 27, 2024
@tianleiwu
Contributor

tianleiwu commented Jun 28, 2024

@11721206, Could you try 1.18.0 or 1.18.1?

1.17.3 supports ScatterND up to opset 13 but your model is opset 17:

1.18.1 supports ScatterND up to opset 18:

@11721206
Author

11721206 commented Jul 1, 2024

@11721206, Could you try 1.18.0 or 1.18.1?

1.17.3 supports ScatterND up to opset 13 but your model is opset 17:

1.18.1 supports ScatterND up to opset 18:

The torch version does not support opset 18, and I used the ONNX version converter to change the opset version, but I got the same issue as with opset 17, i.e. "CUDA kernel not found in registries for Op type: ScatterND".
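
Roughly what I did for the conversion (the output path is just a placeholder):

import onnx
from onnx import version_converter

model = onnx.load("onnx/onnx_cc/onnx_t2s_cc_fsdec.onnx")
converted = version_converter.convert_version(model, 18)  # raise the opset since torch cannot export 18 directly
onnx.save(converted, "onnx_t2s_cc_fsdec_op18.onnx")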

@mindest
Contributor

mindest commented Jul 1, 2024

ONNX_OPERATOR_KERNEL_EX(ScatterND,
kOnnxDomain,
13,

Such lines indicate that the operator is supported since opset version 13, not up to 13.

ScatterND was updated in opset 16 (a new reduction attribute was added). The exported graph targets opset 17, but ORT 1.17.3 still registers the old definition for it, so it fails to find a compatible CUDA kernel.

@11721206 please try, as suggested above, upgrading onnxruntime-gpu to 1.18.0 or 1.18.1. It should work with opset 17 and you don't have to switch the opset version you use.

pip install onnxruntime-gpu==1.18.0  # or 1.18.1

@11721206
Author

11721206 commented Jul 1, 2024

ONNX_OPERATOR_KERNEL_EX(ScatterND,
kOnnxDomain,
13,

Such lines indicate that the operator is supported since opset version 13, not up to 13.
ScatterND was updated in opset 16 (a new reduction attribute was added). The exported graph targets opset 17, but ORT 1.17.3 still registers the old definition for it, so it fails to find a compatible CUDA kernel.

@11721206 please try, as suggested above, upgrading onnxruntime-gpu to 1.18.0 or 1.18.1. It should work with opset 17 and you don't have to switch the opset version you use.

pip install onnxruntime-gpu==1.18.0  # or 1.18.1

I have tried and tested that, but the result is the same. For now I am trying to use lists instead of in-place tensor indexing so that ScatterND is not emitted, which can work around the problem. Are there any other suggestions I can try? I will share my experiment results.
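
For context, this is the kind of rewrite I mean (a simplified sketch, not the actual GPT-SoVITS code): in-place indexed assignment is what usually exports as ScatterND, while building the tensor by concatenation avoids it.

import torch

def update_step_scatter(cache, pos, new_val):
    # Overwriting a position via indexed assignment usually lowers to ScatterND in the exported graph.
    cache[:, pos] = new_val
    return cache

def update_step_concat(cache, new_val):
    # Appending the new step with torch.cat keeps ScatterND out of the exported graph.
    return torch.cat([cache, new_val.unsqueeze(1)], dim=1)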

@mindest
Contributor

mindest commented Jul 1, 2024

@11721206 That is weird; on my end I can reproduce the warning, and it goes away after upgrading. Could you check whether you have other, older onnxruntime packages, e.g. onnxruntime-training, in the environment (pip list | grep onnxruntime)? If so, please uninstall them and reinstall onnxruntime-gpu. I don't have another clue for now.
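
Something like this can also confirm which build is actually being imported (illustrative check):

import onnxruntime as ort

print(ort.__version__)                # should print 1.18.0 or 1.18.1 after the upgrade
print(ort.__file__)                   # shows which installed package is actually imported
print(ort.get_available_providers())  # CUDAExecutionProvider should be listed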


This issue has been automatically marked as stale due to inactivity and will be closed in 30 days if no further activity occurs. If further support is needed, please provide an update and/or more details.

github-actions bot added the stale (issues that have not been addressed in a while; categorized by a bot) label Jul 31, 2024