
tensorRT runtime provider is not thread-safe #19275

Closed
r0l1 opened this issue Jan 25, 2024 · 20 comments
Labels
ep:CUDA (issues related to the CUDA execution provider), ep:TensorRT (issues related to the TensorRT execution provider)

Comments

@r0l1

r0l1 commented Jan 25, 2024

Describe the issue

Doing concurrent session->Run calls without a mutex lock causes corrupt output matrices. The CPU and CUDA runtime providers work as expected; using the TensorRT provider causes this issue.

To reproduce

Forward matrices concurrently to one session without a mutex lock, using the TensorRT runtime provider. The output matrices are corrupt and invalid.
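
A minimal sketch of this reproduction, assuming an already created Ort::Session with the TensorRT EP registered; BuildInput() and the node names are placeholders, not the actual project code:

// Minimal repro sketch: several threads call Run() on one shared session
// with no external lock. BuildInput() is a hypothetical helper and the node
// names are placeholders.
#include <onnxruntime_cxx_api.h>
#include <thread>
#include <vector>

Ort::Value BuildInput();  // hypothetical: fills a [1, 3, 80, 128] float CPU tensor

void ConcurrentRuns(Ort::Session& session, int num_threads) {
    std::vector<std::thread> workers;
    for (int i = 0; i < num_threads; ++i)
        workers.emplace_back([&session] {
            Ort::Value input = BuildInput();
            const char* in_names[]  = {"input"};
            const char* out_names[] = {"output"};
            // With the TensorRT EP, these concurrent Run() calls can return
            // corrupt (all-zero) outputs; CPU and CUDA EPs behave correctly.
            auto outputs = session.Run(Ort::RunOptions{nullptr},
                                       in_names, &input, 1, out_names, 1);
        });
    for (auto& w : workers) w.join();
}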

Urgency

This issue is a release blocker, and we must downgrade to ONNX Runtime 1.12.1 with TensorRT 8.4.3.

Platform

Linux

OS Version

linux 6.7.0 kernel

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

1.16.3

ONNX Runtime API

C++

Architecture

X64

Execution Provider

TensorRT

Execution Provider Library Version

CUDA 11.8.0, TensorRT 8.6.1.6

@github-actions github-actions bot added the ep:CUDA (issues related to the CUDA execution provider) and ep:TensorRT (issues related to the TensorRT execution provider) labels Jan 25, 2024
@jywu-msft
Member

Would it be possible for you to provide us a repro test case so we can look into it deeper from our end?
For 1.16.3 we shouldn't have known concurrency issues; we have tests for multi-threaded inference with the TensorRT EP.

@r0l1
Author

r0l1 commented Jan 25, 2024

I can't provide a test case immediately. Our C++ code is wrapped by Go and I can't provide the source to everything. However, I could provide an ONNX model with a fitting image if that helps?

@chilo-ms
Contributor

Do you know whether the whole model can be run by the TRT EP, or whether there are multiple partitions with some of them assigned to the CUDA EP or CPU?

@chilo-ms
Contributor

I can't provide a test case immediately. Our C++ code is wrapped by Go and I can't provide source to everything. However I could provide an onnx model with a fitting image if that helps?

Yes, please provide the ONNX model.

@r0l1
Author

r0l1 commented Jan 26, 2024

Yes, the model can be run by TRT. We first thought that this was a TensorRT problem and tested it against the TensorRT backend without onnxruntime code.

I uploaded the ONNX model and a simple sample image for forwarding. This is a segmentation task and the output must not be all zeros. I check the output matrix with a countNonZero function and fail the test if the output contains only zero values. Switch between the CPU and TensorRT runtime providers and you will see the different behavior. Important: this only occurs with concurrent forwarding.

Model forwarding info

input:
    name: input
    type: float32
    shape: [1, 3, 80, 128]
    shape-format: NCHW
    sub-pixel-layout: BGR
    mean: {r: 127.5, g: 127.5, b: 127.5}
    std: {r: 127.5, g: 127.5, b: 127.5}

output:
    name: output
    type: int64
    shape: [1, 1, 80, 128]

data.zip

@r0l1
Author

r0l1 commented Jan 26, 2024

Just tested this against the onnxruntime master code; the issue is still present. Please let me know if you need help to reproduce this issue.

@chilo-ms
Contributor

If you can provide a simple C++ program for us to repro, that would be great.

I'm using the "-c" option (number of concurrent runs) of onnx_test_runner (which will be in the release folder when you build from source) to try to repro the issue.

BTW, could you help try the TRT EP provider option "trt_cuda_graph_enable" to enable cuda graph and see whether the concurrency issue still exists?
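
For reference, a minimal sketch of setting that option through the V2 TensorRT provider options in the C++ API; the model path and device id are placeholders, and this is only an illustration of the option keys, not the reporter's setup:

// Sketch: enable trt_cuda_graph_enable via the V2 TensorRT provider options.
#include <onnxruntime_cxx_api.h>

Ort::Session MakeTrtSession(Ort::Env& env, const char* model_path) {
    const OrtApi& api = Ort::GetApi();

    OrtTensorRTProviderOptionsV2* trt_options = nullptr;
    Ort::ThrowOnError(api.CreateTensorRTProviderOptions(&trt_options));

    // Options are passed as parallel key/value string arrays.
    const char* keys[]   = {"device_id", "trt_cuda_graph_enable"};
    const char* values[] = {"0", "1"};
    Ort::ThrowOnError(api.UpdateTensorRTProviderOptions(trt_options, keys, values, 2));

    Ort::SessionOptions session_options;
    session_options.AppendExecutionProvider_TensorRT_V2(*trt_options);

    api.ReleaseTensorRTProviderOptions(trt_options);
    return Ort::Session(env, model_path, session_options);
}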

@jywu-msft
Member

To repro, can we feed all zeros of shape [1, 3, 80, 128] into the input, and if executed concurrently, the output matrix will be zero (when it should be nonzero)? Is that the correct understanding? (Then we don't need the .png image you provided?)
Some other clarifying questions: do you bind input/output to GPU memory, and do you use cuda streams?

@r0l1
Author

r0l1 commented Jan 26, 2024

The input image (input.png) is required. Load it, resize it to 128x80, and apply normalization with the mean & std. If all input values are zero, the output will also be zero.
CUDA streams are not used, and the input matrices are bound to the CPU (host memory).
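
A hedged sketch of that preprocessing (OpenCV, producing NCHW float32 as described in the model info above; the function name and exact layout handling are illustrative, not the project's code):

// Sketch of the described preprocessing: load, resize to 128x80, normalize
// with mean/std 127.5, repack BGR HWC -> NCHW.
#include <opencv2/opencv.hpp>
#include <cstring>
#include <string>
#include <vector>

std::vector<float> Preprocess(const std::string& path) {
    cv::Mat bgr = cv::imread(path, cv::IMREAD_COLOR);              // 8-bit BGR
    cv::resize(bgr, bgr, cv::Size(128, 80), 0, 0, cv::INTER_AREA); // width x height

    cv::Mat f32;
    bgr.convertTo(f32, CV_32FC3, 1.0 / 127.5, -1.0);               // (x - 127.5) / 127.5

    std::vector<cv::Mat> channels(3);
    cv::split(f32, channels);                                      // HWC -> 3 planes (B, G, R)

    std::vector<float> chw(3 * 80 * 128);
    for (int c = 0; c < 3; ++c)
        std::memcpy(chw.data() + c * 80 * 128, channels[c].ptr<float>(),
                    80 * 128 * sizeof(float));
    return chw;                                                    // feed as [1, 3, 80, 128]
}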

Here is a part of our forward code.

void ONNXRuntime::forward(const vector<ModelMat>& inputs, vector<ModelMat>& outputs) {
    const int inputSize  = inputs.size();
    const int outputSize = outputs.size();

    // Check arguments.
    if (inputSize != inputNames_.size()) {
        throw std::runtime_error(
            "ONNXRuntime::forward: expected " + to_string(inputNames_.size()) + " inputs, got " + to_string(inputSize)
        );
    } else if (outputSize != outputNames_.size()) {
        throw std::runtime_error(
            "ONNXRuntime::forward: expected " + to_string(outputNames_.size()) + " outputs, got " + to_string(outputSize)
        );
    }

    // Create the tensors.
    const vector<Ort::Value> inputTensors  = createTensors_(inputs, inputTypes_);
    vector<Ort::Value>       outputTensors = createTensors_(outputs, outputTypes_);

    // TODO: remove once tensorrt issue is fixed: https://github.com/microsoft/onnxruntime/issues/19275
    if (runMxRequired_) {
        unique_lock<std::mutex> lock(runMx_);
        session_->Run(runOpts_, inputNamesCstr_.data(), inputTensors.data(), inputSize, outputNamesCstr_.data(), outputTensors.data(), outputSize);
        return;
    }

    // Forward pass.
    // Mutex is not needed: https://github.com/microsoft/onnxruntime/issues/114
    session_->Run(runOpts_, inputNamesCstr_.data(), inputTensors.data(), inputSize, outputNamesCstr_.data(), outputTensors.data(), outputSize);
}

createTensors_ creates onnxruntime tensors (CPU) from OpenCV matrices (shared reference). But that's nothing special, just a lot of code.
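
For readers following along, a rough sketch of what such a helper might do (the helper name is hypothetical; this is not the project's actual createTensors_):

// Rough sketch of a createTensors_-style helper: wrap an existing CPU float
// buffer (e.g. from a continuous cv::Mat) in an Ort::Value without copying.
// The buffer must stay alive for as long as the tensor is used.
#include <onnxruntime_cxx_api.h>
#include <vector>

Ort::Value WrapAsTensor(float* data, size_t element_count,
                        const std::vector<int64_t>& shape) {
    Ort::MemoryInfo mem_info =
        Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
    return Ort::Value::CreateTensor<float>(mem_info, data, element_count,
                                           shape.data(), shape.size());
}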

Here is our crash testing code in Go:

  1. Load the source image & onnxruntime model with the TensorRT provider
  2. Start NumCPU goroutines and feed the same image to the model/session (normalization happens during the forward)
  3. Check the output and fail if all values are zero.
func init() {
	App.AddCommand(&grumble.Command{
		Name: "crash",
		Help: "crash",
		Run:  runCrash,
		Args: func(a *grumble.Args) {
			a.String("src", "source image")
			a.String("npack", "model")
			a.String("cache", "cache")
		},
	})
}

func runCrash(ctx *grumble.Context) (err error) {
	cl := ctx.App.CloserOneWay()
	defer func() {
		cErr := cl.Close()
		if cErr != nil && err == nil {
			err = cErr
		}
	}()
	go cos.ListenForInterrupts(cl)

	// Our model options.
	lo := nlib.DefaultSemanticSegmentationOptions()
	lo.Backend = nlib.ModelBackend_ONNXRuntime_TensorRT
	lo.GPUID = 0
	lo.ConfThresh = 0.5
	lo.Workers = runtime.NumCPU()
	lo.MatPoolSize = 3 * runtime.NumCPU()
	lo.Interpolation = nlib.InterArea
	lo.ResizeMaskToInput = false
	lo.TRTEnableFP16 = true
	lo.TRTCacheDir = ctx.Args.String("cache")

	// Load the locate model.
	lp, err := npack.OpenFile(cl.CloserTwoWay(), ctx.Args.String("npack"))
	if err != nil {
		return
	}
	model, err := nlib.NewSemanticSegmentationModel(cl.CloserTwoWay(), lp, lo)
	if err != nil {
		return
	}

	src := nlib.NewMat()
	defer src.Free()

	err = src.ReadFromFile(ctx.Args.String("src"))
	if err != nil {
		return
	}
	src.SetReadOnly(true)

	start := time.Now()

	var wg sync.WaitGroup
	for i := 0; i < runtime.NumCPU(); i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()

			mask := nlib.NewMat()
			defer mask.Free()

			for i := 0; i < 20; i++ {
				err := model.Segment(src, mask)
				if err != nil {
					panic(err)
				}

				c, err := mask.CountNonZero()
				if err != nil {
					panic(err)
				} else if c == 0 {
					panic(fmt.Errorf("zero"))
				}
			}
		}()
	}

	wg.Wait()

	fmt.Println("took:", time.Since(start))

	return
}

@jywu-msft
Member

jywu-msft commented Jan 26, 2024

It'd help save us time if you could just provide us the input as either an ONNX TensorProto or a NumPy array. Thanks!
Also, you can try the cuda graph suggestion that @chilo-ms mentioned.

@r0l1
Author

r0l1 commented Jan 26, 2024

Sorry, I missed @chilo-ms's reply. Will be back with more info soon.

@jywu-msft
Member

We think we've identified the root cause. It's a bit complicated: we had previously fixed it but had to back out that fix due to another issue, and unfortunately we weren't able to resume that work until now. We will now work on implementing a full solution and provide you an ETA.
You can help us confirm our hypothesis by building from source at commit c969237 and testing your application again.

@chilo-ms
Contributor

@r0l1
Even though the doc says "Multi-threaded usage is currently not supported in cuda graph", we can still try it with the TRT EP.
First, it requires a warm-up run to capture the cuda graph, and this should be done by one thread (the main thread).
After that, session.Run() can be called concurrently (see the sketch below).
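
A sketch of that warm-up-then-concurrent pattern; RunOnce() is a hypothetical stand-in for whatever prepares inputs and calls session.Run(), not part of the ONNX Runtime API:

// Sketch of the suggested pattern: one single-threaded warm-up Run() so the
// cuda graph can be captured, then concurrent Run() calls from worker threads.
#include <onnxruntime_cxx_api.h>
#include <thread>
#include <vector>

void RunOnce(Ort::Session& session);  // hypothetical: builds inputs and calls session.Run()

void WarmupThenRunConcurrently(Ort::Session& session, int num_threads) {
    RunOnce(session);  // warm-up on the main thread captures the cuda graph

    std::vector<std::thread> workers;
    for (int i = 0; i < num_threads; ++i)
        workers.emplace_back([&session] { RunOnce(session); });  // concurrent runs
    for (auto& w : workers) w.join();
}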

@r0l1
Author

r0l1 commented Jan 27, 2024

@jywu-msft That commit fixes the issue. Thanks for working on a final fix :)

@chilo-ms Thank you for the info. We always did a synchronous warmup straight after loading the model. It would be great if the thread-safety of the session Run were documented in the docs. Last time I checked, I only found some issues talking about it.

I started to prepare a small C++ file to reproduce this issue. Do you still need it, or is there a way to add this to the onnxruntime tests?

chilo-ms added a commit that referenced this issue Jan 30, 2024
Given that InferenceSession::Run() is guaranteed to be thread-safe, meaning multiple threads can call this function concurrently, the TRT EP needs to carefully take care of concurrency here; if not, the following concurrency issue might happen:

- It's suggested that, to perform inference concurrently in multiple streams, one TRT execution context be used per stream. In the design of the TRT EP (which does not apply a per-thread context implementation), if multiple threads call InferenceSession::Run() concurrently, the TRT execution context instance is shared by all the threads and each thread acquires a different stream from ORT. So the TRT EP ends up having one TRT execution context using multiple streams, which is not suggested. But since the whole compute_func() is protected by the lock, and if cudaStreamSynchronize() is enforced there, one TRT execution context per stream is guaranteed.

Therefore, the TRT EP needs to call cudaStreamSynchronize() in compute_func(), i.e. wait until the stream has completed all operations, to prevent the concurrency issue.

GitHub issue: #19275
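
In rough terms, the pattern the commit message describes looks like the following sketch (not the actual TRT EP source; names and the binding/enqueue call are illustrative):

// Illustrative sketch of the described fix: the shared TensorRT execution
// context is used under a lock, and the stream is synchronized before the
// lock is released, so work for only one stream is ever in flight on that
// context at a time.
#include <NvInfer.h>
#include <cuda_runtime.h>
#include <mutex>

std::mutex trt_context_mutex;  // protects the shared IExecutionContext

void ComputeFuncSketch(nvinfer1::IExecutionContext* trt_context,
                       cudaStream_t stream, void* const* bindings) {
    std::lock_guard<std::mutex> lock(trt_context_mutex);
    trt_context->enqueueV2(bindings, stream, nullptr);  // launch on this run's stream
    cudaStreamSynchronize(stream);  // wait here so the next thread gets an idle context
}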
@chilo-ms
Contributor

@r0l1
We have a fix for this concurrency issue and plan to make it into ORT 1.17.
Could you also help verify it? Thank you.

YUNQIUGUO pushed a commit that referenced this issue Jan 30, 2024
@r0l1
Author

r0l1 commented Jan 30, 2024

@chilo-ms that was fast. Just checked out the latest main branch and the issue is fixed now. Thank you very much!

@r0l1 r0l1 closed this as completed Jan 30, 2024
@jywu-msft
Member

jywu-msft commented Jan 30, 2024

@chilo-ms that was fast. Just checked out the latest main branch and the issue is fixed now. Thank you very much!

@r0l1 thanks for reporting the issue and helping us repro/test/validate!

@manickavela29
Contributor

Hi everyone,
Quick question: which version of onnxruntime did this patch go into?
I am working on sherpa-onnx with onnxruntime 1.17.1 and observing some mismatched output with TensorRT enabled, while CPU and CUDA hold good.
If this patch is in a later version, I will have a look and verify!

@r0l1
Author

r0l1 commented Jun 21, 2024

It should have been in the 1.17 release. Check your TensorRT version; the newest version from NVIDIA has some problems.

@manickavela29
Contributor

manickavela29 commented Jul 1, 2024

Hi @r0l1,

Thanks for your suggestion.

Since I am using TensorRT with onnxruntime-gpu 1.17.1 (per the docs, TensorRT 8.6 is compatible), apart from that I have just installed the packages below in my Dockerfile:

docker image: FROM nvidia/cuda:11.8.0-runtime-ubuntu22.04

RUN apt-get -y install libcudnn8=8.7.0.84-1+cuda11.8 \
    libnvinfer8=8.6.1.6-1+cuda11.8 \
    libnvinfer-plugin8=8.6.1.6-1+cuda11.8 \
    libnvonnxparsers8=8.6.1.6-1+cuda11.8

My understanding was that onnxruntime depends on these TensorRT libs, so I am not installing TensorRT separately; I think this should hold good.
Let me know if this is otherwise.

Since 1.17 had this patch, 1.17.1 should be holding good.
Any thoughts on which TensorRT version this fix was made for would be of great help!
