Parallel inference of multiple models in different threads #18806

Open
leugenea opened this issue Dec 13, 2023 · 5 comments
Labels
ep:CUDA issues related to the CUDA execution provider

Comments

@leugenea

Describe the issue

Use case:

  1. I have three models that must run sequentially; some inputs are constant, some inputs change,
  2. Model 2 uses some of the outputs of Model 1, and Model 3 uses some of the outputs of Models 1 and 2,
  3. I'm trying to implement this pipeline optimally on CUDA using the CUDA EP (not TRT, because some inputs have dynamic shapes),
  4. Because the plain Run() method always copies inputs to the GPU and outputs back to the host, I decided to use the IoBinding feature (see the sketch after this list).
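
For context, the pipeline looks roughly like this (a minimal sketch, not the actual production code; the model files, the tensor names "input"/"features"/"result", and the shapes are placeholders, and error handling is omitted):

```cpp
// Minimal sketch: chain two models via IoBinding so the intermediate tensor stays on the GPU.
#include <onnxruntime_cxx_api.h>
#include <vector>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "pipeline");
  Ort::SessionOptions opts;
  OrtCUDAProviderOptions cuda_opts{};           // device 0, default settings otherwise
  opts.AppendExecutionProvider_CUDA(cuda_opts);

  Ort::Session model1(env, "model1.onnx", opts);
  Ort::Session model2(env, "model2.onnx", opts);

  // Inputs start on the host; intermediate tensors should stay on the GPU.
  Ort::MemoryInfo cpu_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  Ort::MemoryInfo cuda_info("Cuda", OrtDeviceAllocator, /*device_id=*/0, OrtMemTypeDefault);

  std::vector<float> input(1 * 3 * 224 * 224, 0.f);
  std::vector<int64_t> shape{1, 3, 224, 224};
  Ort::Value in = Ort::Value::CreateTensor<float>(cpu_info, input.data(), input.size(),
                                                  shape.data(), shape.size());

  // Model 1: bind the CPU input, let ORT allocate the output on the GPU.
  Ort::IoBinding bind1(model1);
  bind1.BindInput("input", in);
  bind1.BindOutput("features", cuda_info);
  model1.Run(Ort::RunOptions{nullptr}, bind1);

  // Model 2: feed Model 1's GPU output directly, no host round-trip.
  std::vector<Ort::Value> outs1 = bind1.GetOutputValues();
  Ort::IoBinding bind2(model2);
  bind2.BindInput("features", outs1[0]);
  bind2.BindOutput("result", cpu_info);         // final result copied back to the host
  model2.Run(Ort::RunOptions{nullptr}, bind2);
  bind2.SynchronizeOutputs();
  return 0;
}
```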

The code I've got copies data to/from the GPU only when I need it to and works just fine in 1 thread, but when I try to execute multiple pipelines Model 1 → Model 2 → Model 3 simultaneously from different threads, I start to get errors like this:

2023-12-13 14:43:34.326899502 [E:onnxruntime:sdk, cuda_call.cc:119 CudaCall] CUDA failure 900: operation not permitted when stream is capturing ; GPU=0 ; hostname=test-server ; expr=cudaDeviceSynchronize(); 
error while inference CUDA failure 900: operation not permitted when stream is capturing ; GPU=0 ; hostname=test-server ; expr=cudaDeviceSynchronize(); 
2023-12-13 14:43:35.809486217 [E:onnxruntime:sdk, cuda_call.cc:119 CudaCall] CUDA failure 900: operation not permitted when stream is capturing ; GPU=0 ; hostname=test-server ; expr=cudaDeviceSynchronize(); 
error while inference CUDA failure 900: operation not permitted when stream is capturing ; GPU=0 ; hostname=test-server ; expr=cudaDeviceSynchronize(); 
2023-12-13 14:43:35.939625872 [E:onnxruntime:, sequential_executor.cc:494 ExecuteKernel] Non-zero status code returned while running Not node. Name:'Not_4064' Status Message: CUDA error cudaErrorStreamCaptureUnsupported:operation not permitted when stream is capturing
error while inference Non-zero status code returned while running Not node. Name:'Not_4064' Status Message: CUDA error cudaErrorStreamCaptureUnsupported:operation not permitted when stream is capturing

The errors can vary, but CUDA failure 900: operation not permitted when stream is capturing and CUDA failure 700: an illegal memory access was encountered are very common.

This issue can be reproduced on different PCs with different OSes, CPUs, and GPUs:

  1. Arch linux + i7-8700 + 1080 Ti,
  2. CentOS7 + Xeon Gold 6248R + V100.

So this is not an environment problem.

I suspect the problem is that I manage GPU memory in the wrong way, but I've read all the available docs a couple of times, studied a lot of tests/examples, and still cannot figure out what I'm doing wrong.

To reproduce

I've tried to reproduce this issue in an isolated program, reusing my original code as much as I could.
The reproducer is based on one of the official examples, and the IoBinding code is adapted from the ONNX Runtime tests.

Attaching the reproducer code since it is 400+ LOC: batch-model-explorer.cpp. Its core threading pattern is sketched below.
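
Heavily condensed, the structure is: one session shared by all threads, with each thread owning its own IoBinding and its own buffers before calling Run() in a loop (tensor names "data" and "resnetv17_dense0_fwd" are what I assume for resnet50-v1-7; error handling and the data-generation logic are omitted):

```cpp
// Condensed sketch of the reproducer's threading structure.
#include <onnxruntime_cxx_api.h>
#include <thread>
#include <vector>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "cuda-multithread-test");
  Ort::SessionOptions opts;
  OrtCUDAProviderOptions cuda_opts{};
  opts.AppendExecutionProvider_CUDA(cuda_opts);
  Ort::Session session(env, "resnet50-v1-7.onnx", opts);   // one session shared by all threads

  auto worker = [&session]() {
    Ort::MemoryInfo cpu_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
    std::vector<int64_t> in_shape{1, 3, 224, 224};
    std::vector<int64_t> out_shape{1, 1000};
    std::vector<float> input(1 * 3 * 224 * 224, 0.f);
    std::vector<float> output(1000, 0.f);

    // Each thread owns its IoBinding and its input/output buffers.
    Ort::IoBinding binding(session);
    Ort::Value in = Ort::Value::CreateTensor<float>(cpu_info, input.data(), input.size(),
                                                    in_shape.data(), in_shape.size());
    Ort::Value out = Ort::Value::CreateTensor<float>(cpu_info, output.data(), output.size(),
                                                     out_shape.data(), out_shape.size());
    binding.BindInput("data", in);
    binding.BindOutput("resnetv17_dense0_fwd", out);

    for (int i = 0; i < 100; ++i) {
      session.Run(Ort::RunOptions{nullptr}, binding);      // concurrent Run() on the shared session
    }
  };

  std::vector<std::thread> threads;
  for (int t = 0; t < 8; ++t) threads.emplace_back(worker);
  for (auto& t : threads) t.join();
  return 0;
}
```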

For reproduction I used one of the public models: resnet50-v1-7.onnx.

Using ONNX Runtime v1.14.1 I was able to reproduce my issue with 4/6/8 threads. A snippet from the reproducer output:

Thread #3 processing data #67
2023-12-13 13:06:03.321060776 [E:onnxruntime:cuda-multithread-test, cuda_call.cc:119 CudaCall] CUDA failure 700: an illegal memory access was encountered ; GPU=0 ; hostname=test-server ; expr=cudaStreamSynchronize(static_cast<cudaStream_t>(stream_)); 
2023-12-13 13:06:03.321077697 [E:onnxruntime:cuda-multithread-test, cuda_call.cc:119 CudaCall] CUDA failure 700: an illegal memory access was encountered ; GPU=0 ; hostname=test-server ; expr=cudaStreamSynchronize(static_cast<cudaStream_t>(stream_)); 
2023-12-13 13:06:03.321057015 [E:onnxruntime:cuda-multithread-test, cuda_call.cc:119 CudaCall] CUDA failure 700: an illegal memory access was encountered ; GPU=0 ; hostname=test-server ; expr=cudaStreamSynchronize(static_cast<cudaStream_t>(stream_)); 
error: /onnxruntime_src/onnxruntime/core/providers/cuda/cuda_call.cc:124 std::conditional_t<THRW, void, onnxruntime::common::Status> onnxruntime::CudaCall(ERRTYPE, const char*, const char*, ERRTYPE, const char*) [with ERRTYPE = cudaError; bool THRW = true; std::conditional_t<THRW, void, onnxruntime::common::Status> = void] /onnxruntime_src/onnxruntime/core/providers/cuda/cuda_call.cc:117 std::conditional_t<THRW, void, onnxruntime::common::Status> onnxruntime::CudaCall(ERRTYPE, const char*, const char*, ERRTYPE, const char*) [with ERRTYPE = cudaError; bool THRW = true; std::conditional_t<THRW, void, onnxruntime::common::Status> = void] CUDA failure 700: an illegal memory access was encountered ; GPU=0 ; hostname=test-server ; expr=cudaStreamSynchronize(static_cast<cudaStream_t>(GetHandle())); 

error: /onnxruntime_src/onnxruntime/core/providers/cuda/cuda_call.cc:124 std::conditional_t<THRW, void, onnxruntime::common::Status> onnxruntime::CudaCall(ERRTYPE, const char*, const char*, ERRTYPE, const char*) [with ERRTYPE = cudaError; bool THRW = true; std::conditional_t<THRW, void, onnxruntime::common::Status> = void] /onnxruntime_src/onnxruntime/core/providers/cuda/cuda_call.cc:117 std::conditional_t<THRW, void, onnxruntime::common::Status> onnxruntime::CudaCall(ERRTYPE, const char*, const char*, ERRTYPE, const char*) [with ERRTYPE = cudaError; bool THRW = true; std::conditional_t<THRW, void, onnxruntime::common::Status> = void] CUDA failure 700: an illegal memory access was encountered ; GPU=0 ; hostname=test-server ; expr=cudaStreamSynchronize(static_cast<cudaStream_t>(GetHandle())); 

error: /onnxruntime_src/onnxruntime/core/providers/cuda/cuda_call.cc:124 std::conditional_t<THRW, void, onnxruntime::common::Status> onnxruntime::CudaCall(ERRTYPE, const char*, const char*, ERRTYPE, const char*) [with ERRTYPE = cudaError; bool THRW = true; std::conditional_t<THRW, void, onnxruntime::common::Status> = void] /onnxruntime_src/onnxruntime/core/providers/cuda/cuda_call.cc:117 std::conditional_t<THRW, void, onnxruntime::common::Status> onnxruntime::CudaCall(ERRTYPE, const char*, const char*, ERRTYPE, const char*) [with ERRTYPE = cudaError; bool THRW = true; std::conditional_t<THRW, void, onnxruntime::common::Status> = void] CUDA failure 700: an illegal memory access was encountered ; GPU=0 ; hostname=test-server ; expr=cudaStreamSynchronize(static_cast<cudaStream_t>(GetHandle())); 

Thread #1 processing data #69
Thread #7 processing data #67

Using ONNX Runtime v1.16.3 (the latest at the moment) I couldn't get errors with the reproducer, but I got the same errors from production code, so I assume the problem is not solved, just harder to reproduce.

Urgency

If this is indeed a bug in ONNX Runtime, then it at least blocks inference with one session on GPU in multithreaded apps.

If this is a problem in my code, then there's a huge gap in the docs describing how to use CUDA EP + IoBinding + threads.

Platform

Linux

OS Version

CentOS 7, Arch Linux

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

v1.14.1, v1.16.3

ONNX Runtime API

C++

Architecture

X64

Execution Provider

CUDA

Execution Provider Library Version

CUDA 11.8

@github-actions github-actions bot added ep:CUDA issues related to the CUDA execution provider ep:TensorRT issues related to TensorRT execution provider labels Dec 13, 2023
@leugenea leugenea changed the title Parallel inference of multiple nets in different threads Parallel inference of multiple models in different threads Dec 13, 2023
@jywu-msft jywu-msft removed the ep:TensorRT issues related to TensorRT execution provider label Dec 13, 2023

This issue has been automatically marked as stale due to inactivity and will be closed in 30 days if no further activity occurs. If further support is needed, please provide an update and/or more details.

@github-actions github-actions bot added the stale issues that have not been addressed in a while; categorized by a bot label Jan 13, 2024
@leugenea
Author

@jywu-msft, is additional info needed?

@github-actions github-actions bot removed the stale issues that have not been addressed in a while; categorized by a bot label Jan 15, 2024
@mdubepsi

I have a very similar problem when trying to Run() a model from multiple threads. I can run 'n' models in 'n' distinct threads without problems, but as soon as I try to run one model in more than one thread, I rapidly run into CUDA error 700. The problem is also less frequent with 1.17 or 1.16.3 than with 1.15.1.

@pedroBBastos

I am having the same problem. Has anyone got a solution?

@leugenea
Author

Still reproducible on the latest version at the moment -- 1.17.3.
