Trouble using sherpa-onnx-offline-websocket-server with the CUDA provider #1053
Comments
Does the server work fine when you use CPU?

Yes, it works fine when using the CPU provider.

Could you tell us how you start the server?
Could you change sherpa-onnx/csrc/offline-websocket-server-impl.cc, lines 95 to 98 (at commit 1f95bff), to:

```cpp
recognizer_.DecodeStreams(p_ss.data(), size);
lock.unlock();
```

recompile, and re-try?
It works fine now. So is there a bug here?
I think it is a bug in onnxruntime: when using the CPU provider, the onnxruntime session is thread-safe.
Hi @csukuangfj, could the issue be occurring because @Vergissmeinicht is using a local onnxruntime?
I build sherpa-onnx with no local onnxruntime; the onnxruntime installation is provided by CMake.
@csukuangfj The server has been running with 200k wav files recognized, and everything works fine except that memory consumption seems to have increased by nearly 3 GB. There are no further modifications to the source code. Is it possible that a memory leak is happening?
Is it CPU RAM or GPU RAM that increased by 3 GB? And do you mean 20 000 wavs or just 200 wav files?
@Vergissmeinicht Could you look into this comment?
It has been serving for 2 days and the memory consumption now stays stable. No more worries about a memory leak! :)
Replied to this comment already. I built the whole project inside a Docker container without any onnxruntime installed.
Are you also running sherpa-onnx inside the Docker container?
Yes. I use nvidia/cuda:11.1.1-cudnn8-devel-ubuntu20.04 as my base Docker image.
Can it be closed now?
So it makes no difference whether the recognizer decodes after or before the unlock?
For the CUDA provider, since the onnxruntime session is not thread-safe, we have to decode first and then unlock. For the CPU provider, the onnxruntime session is thread-safe, so we can unlock first and then decode.
I followed the instructions from https://k2-fsa.github.io/sherpa/onnx/websocket/offline-websocket.html to start a non-streaming websocket server with a transducer model. It works well with the client, too. But when I run the client in multiple threads, i.e., several threads each using a websocket client to recognize wav files one by one at the same time, the server raises a CUDA error:
```
2024-06-24 09:47:01.083093543 [E:onnxruntime:, cuda_call.cc:116 CudaCall] CUDA failure 700: an illegal memory access was encountered ; GPU=0 ; hostname=a2d9f82c2221 ; file=/onnxruntime_src/onnxruntime/core/providers/cuda/cuda_execution_provider.cc ; line=408 ; expr=cudaStreamSynchronize(static_cast<cudaStream_t>(stream_));
2024-06-24 09:47:01.083005575 [E:onnxruntime:, cuda_call.cc:116 CudaCall] CUDA failure 700: an illegal memory access was encountered ; GPU=0 ; hostname=a2d9f82c2221 ; file=/onnxruntime_src/onnxruntime/core/providers/cuda/gpu_data_transfer.cc ; line=73 ; expr=cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToHost, static_cast<cudaStream_t>(stream.GetHandle()));
terminate called after throwing an instance of 'Ort::Exception'
  what():  CUDA failure 700: an illegal memory access was encountered ; GPU=0 ; hostname=a2d9f82c2221 ; file=/onnxruntime_src/onnxruntime/core/providers/cuda/cuda_execution_provider.cc ; line=408 ; expr=cudaStreamSynchronize(static_cast<cudaStream_t>(stream_));
Aborted
```
My server runs on GeForce RTX 4090 / driver 535.104.05 / CUDA version: 12.2.
I'd be glad to have your help.