Skip to content

Commit

Permalink
Fix bug caused by using a local buffer
Browse files — browse the repository at this point in the history
  • Loading branch information
chilo-ms committed Dec 5, 2023
1 parent 8de13db commit 27ea00e
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1052,6 +1052,8 @@ Status BindKernelOutput(Ort::KernelContext& ctx,
char const* output_name,
size_t output_index,
size_t output_type,
std::vector<IAllocatorUniquePtr<void>>& scratch_buffers,
std::unordered_map<char const*, void*>& buffers,
OrtAllocator* alloc,
cudaStream_t stream) {
auto allocator = allocator_map[output_name];
Expand Down Expand Up @@ -1101,9 +1103,10 @@ Status BindKernelOutput(Ort::KernelContext& ctx,
output_dim_size *= shape[i];
}
}
IAllocatorUniquePtr<int64_t> buffer = IAllocator::MakeUniquePtrFromOrtAllocator<int64_t>(alloc, output_dim_size * sizeof(int64_t));
cuda::Impl_Cast<int32_t, int64_t>(stream, reinterpret_cast<int32_t*>(allocator->getBuffer()), buffer.get(), output_dim_size);
Ort::ThrowOnError(Ort::GetApi().CreateTensorWithDataAsOrtValue(mem_info, buffer.get(), output_dim_size * sizeof(int64_t),
scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, output_dim_size * sizeof(int64_t)));

Check warning on line 1106 in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc#L1106

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:1106:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
buffers[output_name] = scratch_buffers.back().get();
cuda::Impl_Cast<int32_t, int64_t>(stream, reinterpret_cast<int32_t*>(allocator->getBuffer()), reinterpret_cast<int64_t*>(buffers[output_name]), output_dim_size);

Check warning on line 1108 in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc#L1108

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:1108:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
Ort::ThrowOnError(Ort::GetApi().CreateTensorWithDataAsOrtValue(mem_info, buffers[output_name], output_dim_size * sizeof(int64_t),

Check warning on line 1109 in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc#L1109

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:1109:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
shape.data(), shape.size(), Ort::TypeToTensorType<int64_t>::type, &out));

Check warning on line 1110 in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc#L1110

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:1110:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
break;
}
Expand All @@ -1119,9 +1122,10 @@ Status BindKernelOutput(Ort::KernelContext& ctx,
output_dim_size *= shape[i];
}
}
IAllocatorUniquePtr<double> buffer = IAllocator::MakeUniquePtrFromOrtAllocator<double>(alloc, output_dim_size * sizeof(double));
cuda::Impl_Cast<float, double>(stream, reinterpret_cast<float*>(allocator->getBuffer()), buffer.get(), output_dim_size);
Ort::ThrowOnError(Ort::GetApi().CreateTensorWithDataAsOrtValue(mem_info, buffer.get(), output_dim_size * sizeof(double),
scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, output_dim_size * sizeof(double)));

Check warning on line 1125 in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc#L1125

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:1125:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
buffers[output_name] = scratch_buffers.back().get();
cuda::Impl_Cast<float, double>(stream, reinterpret_cast<float*>(allocator->getBuffer()), reinterpret_cast<double*>(buffers[output_name]), output_dim_size);

Check warning on line 1127 in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc#L1127

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:1127:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
Ort::ThrowOnError(Ort::GetApi().CreateTensorWithDataAsOrtValue(mem_info, buffers[output_name], output_dim_size * sizeof(double),

Check warning on line 1128 in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc#L1128

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:1128:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
shape.data(), shape.size(), Ort::TypeToTensorType<double>::type, &out));

Check warning on line 1129 in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc#L1129

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:1129:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
break;
}
Expand Down Expand Up @@ -3354,7 +3358,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd

// Assign TRT output back to ORT output
// (1) Bind TRT DDS output to ORT kernel context output. (It needs to wait until enqueueV3 is finished)
// (2) Cast TRT INT32 output to ORT INT64 output or TRT double output to float output
// (2) Cast TRT INT32 output to ORT INT64 output or TRT float output to double output
for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
char const* output_name = output_binding_names[i];

Expand All @@ -3370,7 +3374,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
if (index_iter != output_indexes.end()) {
output_index = index_iter->second;
}
auto status = BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type, alloc, stream);
auto status = BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type, scratch_buffers, buffers, alloc, stream);

Check warning on line 3377 in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc#L3377

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:3377:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
if (status != Status::OK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage());
}
Expand Down

0 comments on commit 27ea00e

Please sign in to comment.