
[TensorRT EP] Add unit test for user provided cuda stream (#17974)
Add a unit test for a user-provided CUDA stream
chilo-ms authored Oct 24, 2023
1 parent 4ffd022 commit 555b2af
135 changes: 126 additions & 9 deletions onnxruntime/test/shared_lib/test_inference.cc
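
For context, here is a minimal sketch (not part of the commit) of the pattern the new test exercises: handing an application-owned CUDA stream to the TensorRT EP through the "user_compute_stream" provider option. It reuses the C API calls that appear in the diff; error handling is elided and "model.onnx" is a placeholder path.

#include <cuda_runtime.h>
#include "onnxruntime_cxx_api.h"

int main() {
  const auto& api = Ort::GetApi();

  // The application owns the stream and stays responsible for destroying it.
  cudaStream_t stream = nullptr;
  cudaStreamCreate(&stream);

  OrtTensorRTProviderOptionsV2* trt_options = nullptr;
  api.CreateTensorRTProviderOptions(&trt_options);
  // Hand the raw stream handle to the TensorRT EP.
  api.UpdateTensorRTProviderOptionsWithValue(trt_options, "user_compute_stream", stream);

  Ort::SessionOptions session_options;
  api.SessionOptionsAppendExecutionProvider_TensorRT_V2(
      static_cast<OrtSessionOptions*>(session_options), trt_options);

  Ort::Env env;
  Ort::Session session(env, "model.onnx", session_options);
  // ... bind inputs/outputs and call session.Run(), as the test below does ...

  api.ReleaseTensorRTProviderOptions(trt_options);
  cudaStreamDestroy(stream);
  return 0;
}

The added test follows this same flow, then drives two inferences through IoBinding-backed buffers.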
@@ -2832,6 +2832,132 @@ TEST(CApiTest, ConfigureCudaArenaAndDemonstrateMemoryArenaShrinkage) {
#endif

#ifdef USE_TENSORRT
TEST(CApiTest, TestExternalCUDAStreamWithIOBinding) {
  const auto& api = Ort::GetApi();
  Ort::SessionOptions session_options;

  OrtTensorRTProviderOptionsV2* trt_options;
  ASSERT_TRUE(api.CreateTensorRTProviderOptions(&trt_options) == nullptr);
  std::unique_ptr<OrtTensorRTProviderOptionsV2, decltype(api.ReleaseTensorRTProviderOptions)>
      rel_trt_options(trt_options, api.ReleaseTensorRTProviderOptions);

  // Update the provider options with a user-provided compute stream
  cudaStream_t compute_stream = nullptr;
  void* user_compute_stream = nullptr;
  cudaStreamCreate(&compute_stream);
  ASSERT_TRUE(api.UpdateTensorRTProviderOptionsWithValue(rel_trt_options.get(), "user_compute_stream", compute_stream) == nullptr);
  ASSERT_TRUE(api.GetTensorRTProviderOptionsByName(rel_trt_options.get(), "user_compute_stream", &user_compute_stream) == nullptr);
  ASSERT_TRUE(user_compute_stream == (void*)compute_stream);
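  // Reading the option back verifies the stored handle; setting "user_compute_stream"
  // is also expected to flip has_user_compute_stream internally (not asserted here).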

  ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2(
                  static_cast<OrtSessionOptions*>(session_options),
                  rel_trt_options.get()) == nullptr);

  Ort::Session session(*ort_env, MODEL_URI, session_options);
  Ort::MemoryInfo info_cuda("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault);

  const std::array<int64_t, 2> x_shape = {3, 2};
  std::array<float, 3 * 2> x_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};

  /*
   * Use cudaMallocHost() (pinned memory allocation) to create input/output tensors
   */
  float* input_data;
  cudaMallocHost(&input_data, 3 * 2 * sizeof(float));
  ASSERT_NE(input_data, nullptr);
  cudaMemcpy(input_data, x_values.data(), sizeof(float) * x_values.size(), cudaMemcpyHostToDevice);
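  // Pinned host memory is directly addressable by the device under CUDA's unified
  // virtual addressing, which is what lets the "Cuda"-tagged OrtValues below be
  // backed by these host buffers.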

std::cout << "pinned memory allocation" << std::endl;
std::cout << "input tesnor:" << std::endl;
for (int i = 0; i < 6; i++) {
std::cout << input_data[i] << std::endl;
}

  // Create an OrtValue tensor backed by data on CUDA memory
  Ort::Value bound_x = Ort::Value::CreateTensor(info_cuda, reinterpret_cast<float*>(input_data), x_values.size(),
                                                x_shape.data(), x_shape.size());

  const std::array<int64_t, 2> expected_y_shape = {3, 2};
  std::array<float, 3 * 2> expected_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f};

  float* output_data;
  cudaMallocHost(&output_data, 3 * 2 * sizeof(float));
  ASSERT_NE(output_data, nullptr);

  // Create an OrtValue tensor backed by data on CUDA memory
  Ort::Value bound_y = Ort::Value::CreateTensor(info_cuda, reinterpret_cast<float*>(output_data),
                                                expected_y.size(), expected_y_shape.data(), expected_y_shape.size());

  // Create IoBinding for inputs and outputs.
  Ort::IoBinding binding(session);
  binding.BindInput("X", bound_x);
  binding.BindOutput("Y", bound_y);
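  // BindInput/BindOutput tie the pre-allocated buffers to the model's input and
  // output names, so Run() reads and writes them in place with no extra copies.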

  /*
   * Use cudaMalloc() (device memory allocation; host-to-device copies from pageable
   * memory are staged through an implicit pinned buffer) to create input/output tensors
   */
  float* input_data_2;
  cudaMalloc(&input_data_2, 3 * 2 * sizeof(float));
  ASSERT_NE(input_data_2, nullptr);
  cudaMemcpy(input_data_2, x_values.data(), sizeof(float) * x_values.size(), cudaMemcpyHostToDevice);

  // Create an OrtValue tensor backed by data on CUDA memory
  Ort::Value bound_x_2 = Ort::Value::CreateTensor(info_cuda, reinterpret_cast<float*>(input_data_2), x_values.size(),
                                                  x_shape.data(), x_shape.size());

  float* output_data_2;
  cudaMalloc(&output_data_2, 3 * 2 * sizeof(float));
  ASSERT_NE(output_data_2, nullptr);

  // Create an OrtValue tensor backed by data on CUDA memory
  Ort::Value bound_y_2 = Ort::Value::CreateTensor(info_cuda, reinterpret_cast<float*>(output_data_2),
                                                  expected_y.size(), expected_y_shape.data(), expected_y_shape.size());

  // Create IoBinding for inputs and outputs.
  Ort::IoBinding binding_2(session);
  binding_2.BindInput("X", bound_x_2);
  binding_2.BindOutput("Y", bound_y_2);

  // Run with the first IoBinding
  session.Run(Ort::RunOptions(), binding);

  // Check the values against the bound raw memory (needs copying from device to host first)
  std::array<float, 3 * 2> y_values;
  cudaMemcpy(y_values.data(), output_data, sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost);
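  // cudaMemcpy runs on the legacy default stream, which synchronizes with blocking
  // streams such as compute_stream (created with cudaStreamCreate), so the inference
  // result is complete before the copy returns.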

std::cout << "pinned memory allocation" << std::endl;
std::cout << "output: " << std::endl;
for (auto y : y_values) {
std::cout << y << std::endl;
}
ASSERT_THAT(y_values, ::testing::ContainerEq(expected_y));

  // Run with the second IoBinding
  session.Run(Ort::RunOptions(), binding_2);

  // Check the values against the bound raw memory (needs copying from device to host first)
  cudaMemcpy(y_values.data(), output_data_2, sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost);

  std::cout << "device memory allocation" << std::endl;
  std::cout << "output: " << std::endl;
  for (auto y : y_values) {
    std::cout << y << std::endl;
  }
  ASSERT_THAT(y_values, ::testing::ContainerEq(expected_y));

  // Clean up
  binding.ClearBoundInputs();
  binding.ClearBoundOutputs();
  binding_2.ClearBoundInputs();
  binding_2.ClearBoundOutputs();

  cudaFreeHost(input_data);
  cudaFreeHost(output_data);
  cudaFree(input_data_2);
  cudaFree(output_data_2);
  cudaStreamDestroy(compute_stream);
}

class CApiTensorRTTest : public testing::Test, public ::testing::WithParamInterface<std::string> {};

// This test uses CreateTensorRTProviderOptions/UpdateTensorRTProviderOptions APIs to configure and create a TensorRT Execution Provider
@@ -2849,15 +2975,6 @@ TEST_P(CApiTensorRTTest, TestConfigureTensorRTProviderOptions) {
  ASSERT_TRUE(api.CreateTensorRTProviderOptions(&trt_options) == nullptr);
  std::unique_ptr<OrtTensorRTProviderOptionsV2, decltype(api.ReleaseTensorRTProviderOptions)> rel_trt_options(trt_options, api.ReleaseTensorRTProviderOptions);

  // Only test updating provider option with user provided compute stream
  cudaStream_t compute_stream = nullptr;
  void* user_compute_stream = nullptr;
  cudaStreamCreateWithFlags(&compute_stream, cudaStreamNonBlocking);
  ASSERT_TRUE(api.UpdateTensorRTProviderOptionsWithValue(rel_trt_options.get(), "user_compute_stream", compute_stream) == nullptr);
  ASSERT_TRUE(api.GetTensorRTProviderOptionsByName(rel_trt_options.get(), "user_compute_stream", &user_compute_stream) == nullptr);
  ASSERT_TRUE(user_compute_stream == (void*)compute_stream);
  cudaStreamDestroy(compute_stream);

  const char* engine_cache_path = "./trt_engine_folder";

  std::vector<const char*> keys{"device_id", "has_user_compute_stream", "trt_fp16_enable", "trt_int8_enable", "trt_engine_cache_enable",
