diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 1f8b4c348bc89..184a60aa041fb 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1052,6 +1052,8 @@ Status BindKernelOutput(Ort::KernelContext& ctx, char const* output_name, size_t output_index, size_t output_type, + std::vector>& scratch_buffers, + std::unordered_map& buffers, OrtAllocator* alloc, cudaStream_t stream) { auto allocator = allocator_map[output_name]; @@ -1101,9 +1103,10 @@ Status BindKernelOutput(Ort::KernelContext& ctx, output_dim_size *= shape[i]; } } - IAllocatorUniquePtr buffer = IAllocator::MakeUniquePtrFromOrtAllocator(alloc, output_dim_size * sizeof(int64_t)); - cuda::Impl_Cast(stream, reinterpret_cast(allocator->getBuffer()), buffer.get(), output_dim_size); - Ort::ThrowOnError(Ort::GetApi().CreateTensorWithDataAsOrtValue(mem_info, buffer.get(), output_dim_size * sizeof(int64_t), + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, output_dim_size * sizeof(int64_t))); + buffers[output_name] = scratch_buffers.back().get(); + cuda::Impl_Cast(stream, reinterpret_cast(allocator->getBuffer()), reinterpret_cast(buffers[output_name]), output_dim_size); + Ort::ThrowOnError(Ort::GetApi().CreateTensorWithDataAsOrtValue(mem_info, buffers[output_name], output_dim_size * sizeof(int64_t), shape.data(), shape.size(), Ort::TypeToTensorType::type, &out)); break; } @@ -1119,9 +1122,10 @@ Status BindKernelOutput(Ort::KernelContext& ctx, output_dim_size *= shape[i]; } } - IAllocatorUniquePtr buffer = IAllocator::MakeUniquePtrFromOrtAllocator(alloc, output_dim_size * sizeof(double)); - cuda::Impl_Cast(stream, reinterpret_cast(allocator->getBuffer()), buffer.get(), output_dim_size); - Ort::ThrowOnError(Ort::GetApi().CreateTensorWithDataAsOrtValue(mem_info, buffer.get(), output_dim_size * sizeof(double), + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, output_dim_size * sizeof(double))); + buffers[output_name] = scratch_buffers.back().get(); + cuda::Impl_Cast(stream, reinterpret_cast(allocator->getBuffer()), reinterpret_cast(buffers[output_name]), output_dim_size); + Ort::ThrowOnError(Ort::GetApi().CreateTensorWithDataAsOrtValue(mem_info, buffers[output_name], output_dim_size * sizeof(double), shape.data(), shape.size(), Ort::TypeToTensorType::type, &out)); break; } @@ -3354,7 +3358,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorsecond; } - auto status = BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type, alloc, stream); + auto status = BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type, scratch_buffers, buffers, alloc, stream); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage()); }