Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TensorRT EP] Switch to enqueueV3 with support DDS output #17751

Closed
wants to merge 31 commits into from
Closed
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
de170d5
update
chilo-ms Sep 28, 2023
3affd2e
update
chilo-ms Sep 30, 2023
886bcce
update
chilo-ms Oct 2, 2023
ff9d3d2
update
chilo-ms Oct 10, 2023
ae69ca7
merge
chilo-ms Oct 10, 2023
09f2401
update
chilo-ms Oct 11, 2023
1e77cd1
Merge branch 'main' into chi/trt_enqueue_v3
chilo-ms Oct 11, 2023
5f82028
Merge branch 'main' into chi/trt_enqueue_v3
chilo-ms Oct 17, 2023
8f847ec
fix bug
chilo-ms Oct 17, 2023
35d54b8
fix bugs
chilo-ms Oct 17, 2023
5330dff
update
chilo-ms Oct 18, 2023
ed4849f
Merge branch 'main' into chi/trt_enqueue_v3
chilo-ms Oct 23, 2023
98b35ed
refactor
chilo-ms Oct 30, 2023
018a6b4
Merge branch 'main' into chi/trt_enqueue_v3
chilo-ms Oct 30, 2023
0e79992
fix format
chilo-ms Oct 30, 2023
bc7e206
fix minor bug
chilo-ms Nov 1, 2023
57ba208
remove redundant code
chilo-ms Nov 1, 2023
b1ec7cd
code refactor
chilo-ms Nov 3, 2023
9653143
fix format
chilo-ms Nov 3, 2023
4e40cd3
Merge branch 'main' into chi/trt_enqueue_v3
chilo-ms Nov 3, 2023
ecc2566
update
chilo-ms Nov 4, 2023
7ee0ee0
update
chilo-ms Nov 5, 2023
98c8374
Merge branch 'main' into chi/trt_enqueue_v3
chilo-ms Nov 7, 2023
8567688
Merge branch 'main' into chi/trt_enqueue_v3
chilo-ms Nov 15, 2023
dc88e6b
Merge branch 'main' into chi/trt_enqueue_v3
chilo-ms Nov 19, 2023
2413e1d
Merge branch 'main' into chi/trt_enqueue_v3
chilo-ms Dec 1, 2023
99e79dd
Add INT32/INT64 and float/double conversion for DDS outputs
chilo-ms Dec 5, 2023
e15a8bd
update for adding INT32/INT64 and float/double conversion for DDS out…
chilo-ms Dec 5, 2023
8de13db
fix typo
chilo-ms Dec 5, 2023
27ea00e
fix bug for using local buffer
chilo-ms Dec 5, 2023
4ff9a85
code refactor and add cleanup for dds_output_allocator_map
chilo-ms Dec 5, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions include/onnxruntime/core/framework/op_kernel_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,10 @@ class OpKernelContext {
*/
AllocatorPtr GetAllocator(const OrtDevice& device) const;

#if defined(ENABLE_ATEN) || defined(USE_TENSORRT)
Status SetOutputMLValue(int index, const OrtValue& ort_value);
#endif

protected:
OpKernelContext(concurrency::ThreadPool* threadpool, const logging::Logger& logger, Stream* stream);

Expand All @@ -195,10 +199,6 @@ class OpKernelContext {
const OrtValue* GetImplicitInputMLValue(int index) const;
OrtValue* GetOutputMLValue(int index);

#ifdef ENABLE_ATEN
Status SetOutputMLValue(int index, const OrtValue& ort_value);
#endif

// Creates the OrtValue* based on the shape, if it does not exist
virtual OrtValue* OutputMLValue(int index, const TensorShape& shape);

Expand Down
9 changes: 9 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -4520,6 +4520,15 @@ struct OrtApi {
* \since Version 1.17.
*/
ORT_API2_STATUS(ReadOpAttr, _In_ const OrtOpAttr* op_attr, _In_ OrtOpAttrType type, _Inout_ void* data, _In_ size_t len, _Out_ size_t* out);

/** \brief Used for custom operators: sets the kernel output at the given index to the given OrtValue
*
* \see ::OrtCustomOp
*
* \since Version 1.17.
*/
ORT_API2_STATUS(KernelContext_SetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

KernelContext_SetOutput

Why do we want to expose it in the C API? Do we have a TRT-based custom op?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

KernelContext_SetOutput

Why do we want to expose it in the C API? Do we have a TRT-based custom op?

Compile-API-based EPs need to implement compute_func, which only has access to the public OrtKernelContext API, not the internal OpKernelContext API.
We need to call SetOutputMLValue(), which is why it is plumbed through to the public API in this PR.

_In_ const OrtValue* ort_value);
};

/*
Expand Down
1 change: 1 addition & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2052,6 +2052,7 @@ struct KernelContext {
ConstValue GetInput(size_t index) const;
UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const;
UnownedValue GetOutput(size_t index, const std::vector<int64_t>& dims) const;
void SetOutput(size_t index, const OrtValue& ort_value);
void* GetGPUComputeStream() const;
Logger GetLogger() const;
OrtAllocator* GetAllocator(const OrtMemoryInfo& memory_info) const;
Expand Down
4 changes: 4 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -1634,6 +1634,10 @@ inline UnownedValue KernelContext::GetOutput(size_t index, const std::vector<int
return UnownedValue(out);
}

// Sets the kernel output at `index` to `ort_value` by forwarding to the C API
// entry KernelContext_SetOutput; the returned status is handed to
// Ort::ThrowOnError.
inline void KernelContext::SetOutput(size_t index, const OrtValue& ort_value) {
  const OrtValue* const value_ptr = &ort_value;
  auto* const status = GetApi().KernelContext_SetOutput(ctx_, index, value_ptr);
  Ort::ThrowOnError(status);
}

inline void* KernelContext::GetGPUComputeStream() const {
void* out = nullptr;
Ort::ThrowOnError(GetApi().KernelContext_GetGPUComputeStream(ctx_, &out));
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/execution_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ IExecutionFrame::IExecutionFrame(const OrtValueNameIdxMap& ort_value_idx_map,

IExecutionFrame::~IExecutionFrame() = default;

#ifdef ENABLE_ATEN
#if defined(ENABLE_ATEN) || defined(USE_TENSORRT)
Status IExecutionFrame::SetOutputMLValue(int index, const OrtValue& ort_value) {
int ort_value_idx = GetNodeIdxToMLValueIdx(index);
if (ort_value_idx == NodeIndexInfo::kInvalidEntry || static_cast<size_t>(ort_value_idx) >= all_values_size_) {
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/execution_frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class IExecutionFrame {
const OrtValue* GetNodeInputOrOutputMLValue(int index) const;
OrtValue* GetMutableNodeInputOrOutputMLValue(int index);

#ifdef ENABLE_ATEN
#if defined(ENABLE_ATEN) || defined(USE_TENSORRT)
// Override the index-th output with ort_value
Status SetOutputMLValue(int index, const OrtValue& ort_value);
#endif
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/op_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ AllocatorPtr OpKernelContext::GetAllocator(const OrtDevice& device) const {
return execution_frame_->GetAllocator(device);
}

#ifdef ENABLE_ATEN
#if defined(ENABLE_ATEN) || defined(USE_TENSORRT)
Status OpKernelContext::SetOutputMLValue(int index, const OrtValue& ort_value) {
if (index < 0 || index >= OutputCount()) {
return Status(common::ONNXRUNTIME, common::FAIL,
Expand Down
Loading
Loading