Commit
code refactor and add cleanup for dds_output_allocator_map
chilo-ms committed Dec 5, 2023
1 parent 27ea00e commit 4ff9a85
Showing 2 changed files with 16 additions and 9 deletions.
23 changes: 15 additions & 8 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -1095,7 +1095,7 @@ Status BindKernelOutput(Ort::KernelContext& ctx,
       // The allocation buffer holds the INT32 output data since TRT doesn't support INT64 but INT32.
       // So, we need to cast the data from INT32 to INT64 and then set INT64 output data to kernel context.
       SafeInt<int> output_dim_size(1);
-      for (int i = 0; i < shape.size(); ++i) {
+      for (size_t i = 0; i < shape.size(); ++i) {
         if (shape[i] == 0) {
           output_dim_size = 1;
           break;
@@ -1104,17 +1104,17 @@ Status BindKernelOutput(Ort::KernelContext& ctx,
         }
       }
       scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, output_dim_size * sizeof(int64_t)));
-      buffers[output_name] = scratch_buffers.back().get();
-      cuda::Impl_Cast<int32_t, int64_t>(stream, reinterpret_cast<int32_t*>(allocator->getBuffer()), reinterpret_cast<int64_t*>(buffers[output_name]), output_dim_size);
-      Ort::ThrowOnError(Ort::GetApi().CreateTensorWithDataAsOrtValue(mem_info, buffers[output_name], output_dim_size * sizeof(int64_t),
+      auto data = scratch_buffers.back().get();
+      cuda::Impl_Cast<int32_t, int64_t>(stream, reinterpret_cast<int32_t*>(allocator->getBuffer()), reinterpret_cast<int64_t*>(data), output_dim_size);
+      Ort::ThrowOnError(Ort::GetApi().CreateTensorWithDataAsOrtValue(mem_info, data, output_dim_size * sizeof(int64_t),
                                                                      shape.data(), shape.size(), Ort::TypeToTensorType<int64_t>::type, &out));
       break;
     }
     case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: {
       // The allocation buffer holds the FLOAT output data since TRT doesn't support DOUBLE but FLOAT.
       // So, we need to cast the data from FLOAT to DOUBLE and then set DOUBLE output data to kernel context.
       SafeInt<int> output_dim_size(1);
-      for (int i = 0; i < shape.size(); ++i) {
+      for (size_t i = 0; i < shape.size(); ++i) {
         if (shape[i] == 0) {
           output_dim_size = 1;
           break;
@@ -1123,9 +1123,9 @@ Status BindKernelOutput(Ort::KernelContext& ctx,
         }
       }
       scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, output_dim_size * sizeof(double)));
-      buffers[output_name] = scratch_buffers.back().get();
-      cuda::Impl_Cast<float, double>(stream, reinterpret_cast<float*>(allocator->getBuffer()), reinterpret_cast<double*>(buffers[output_name]), output_dim_size);
-      Ort::ThrowOnError(Ort::GetApi().CreateTensorWithDataAsOrtValue(mem_info, buffers[output_name], output_dim_size * sizeof(double),
+      auto data = scratch_buffers.back().get();
+      cuda::Impl_Cast<float, double>(stream, reinterpret_cast<float*>(allocator->getBuffer()), reinterpret_cast<double*>(data), output_dim_size);
+      Ort::ThrowOnError(Ort::GetApi().CreateTensorWithDataAsOrtValue(mem_info, data, output_dim_size * sizeof(double),
                                                                      shape.data(), shape.size(), Ort::TypeToTensorType<double>::type, &out));
       break;
     }
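Both hunks above share one pattern: TensorRT writes INT32 (or FLOAT) output into a scratch allocation, and the provider widens it into the INT64 (or DOUBLE) buffer the kernel context expects, while the loop-index change from `int` to `size_t` silences the signed/unsigned comparison against `shape.size()`. Below is a minimal CPU-side sketch of that pattern; the real code performs the cast on-device through `cuda::Impl_Cast`, and the helper names here (`OutputDimSize`, `CastInt32ToInt64`) are illustrative, not ORT or TRT API.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Element count for the output shape; a zero dimension means an empty
// tensor, yet the code above still sizes the scratch buffer for one element.
std::size_t OutputDimSize(const std::vector<int64_t>& shape) {
  std::size_t size = 1;
  for (std::size_t i = 0; i < shape.size(); ++i) {  // size_t index, as in the fix above
    if (shape[i] == 0) return 1;
    size *= static_cast<std::size_t>(shape[i]);
  }
  return size;
}

// CPU stand-in for cuda::Impl_Cast<int32_t, int64_t>: widen each element
// from the INT32 buffer TRT wrote into the INT64 buffer ORT will expose.
void CastInt32ToInt64(const int32_t* src, int64_t* dst, std::size_t count) {
  for (std::size_t i = 0; i < count; ++i) {
    dst[i] = static_cast<int64_t>(src[i]);
  }
}
```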
@@ -1659,6 +1659,13 @@ TensorrtExecutionProvider::~TensorrtExecutionProvider() {
     // We can't get api inside destructor so that's why we duplicate the code here.
     delete static_cast<OrtAllocatorImpl*>(alloc_);
   }

+  for (auto iter_outer = dds_output_allocator_map_.begin(); iter_outer != dds_output_allocator_map_.end(); ++iter_outer) {
+    auto inner_map = iter_outer->second;
+    for (auto iter_inner = inner_map.begin(); iter_inner != inner_map.end(); ++iter_inner) {
+      delete iter_inner->second;
+    }
+  }
 }

 bool TensorrtExecutionProvider::IsGraphCaptureEnabled() const {
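The new destructor loop copies each inner map (`auto inner_map = iter_outer->second;`) before deleting its values, which is correct but pays for a map copy per fused node. Here is a sketch of an equivalent cleanup with references and range-based for loops, assuming, as the loop above implies, that `DDSOutputAllocatorMap` maps output names to raw allocator pointers owned by the provider:

```cpp
// Same cleanup without copying the inner maps: iterate by reference and
// delete each raw allocator pointer the provider owns.
for (auto& outer_entry : dds_output_allocator_map_) {
  for (auto& inner_entry : outer_entry.second) {
    delete inner_entry.second;  // raw pointer owned by this provider
  }
}
```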
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -320,7 +320,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
   std::unordered_map<std::string, std::vector<std::vector<int64_t>>> profile_opt_shapes_;
   std::unordered_map<std::string, ShapeRangesMap> input_shape_ranges_;  // The profile shape ranges that the engine is built with
   std::unordered_map<std::string, std::vector<nvinfer1::IOptimizationProfile*>> profiles_;
-  std::unordered_map<std::string, DDSOutputAllocatorMap> dds_output_allocator_map_;  // For DDS output tensor
+  std::unordered_map<std::string, DDSOutputAllocatorMap> dds_output_allocator_map_;  // For DDS output tensor. TODO: Make DDSOutputAllocatorMap use unique_ptr

   // for external stream, we need to create its cudnn/cublas handle before the cuda EP enables cuda graph capture
   cudnnHandle_t external_cudnn_handle_ = nullptr;
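The TODO added to this comment points at the longer-term fix: if `DDSOutputAllocatorMap` held `std::unique_ptr` values, the manual delete loop added to the destructor would become unnecessary. A sketch of that direction, where `OutputAllocator` is a stand-in for the provider's actual allocator type rather than the real declaration:

```cpp
#include <memory>
#include <string>
#include <unordered_map>

// Stand-in for the TRT EP's DDS output allocator type (assumed name).
class OutputAllocator {
 public:
  virtual ~OutputAllocator() = default;
};

// Owning inner map: output name -> allocator. Destroying it releases
// every allocator without an explicit delete loop.
using DDSOutputAllocatorMap =
    std::unordered_map<std::string, std::unique_ptr<OutputAllocator>>;

// Outer map: fused-node name -> per-output allocators. With unique_ptr
// values, ~TensorrtExecutionProvider() needs no cleanup code for this map.
std::unordered_map<std::string, DDSOutputAllocatorMap> dds_output_allocator_map_;
```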
