Commit

update

chilo-ms committed Oct 31, 2023
1 parent 348a963 commit f8099b1
Showing 14 changed files with 1,879 additions and 1,036 deletions.
8 changes: 4 additions & 4 deletions include/onnxruntime/core/framework/op_kernel_context.h
@@ -186,6 +186,10 @@ class OpKernelContext {
*/
AllocatorPtr GetAllocator(const OrtDevice& device) const;

#if defined(ENABLE_ATEN) || defined(USE_TENSORRT)
Status SetOutputMLValue(int index, const OrtValue& ort_value);
#endif

protected:
OpKernelContext(concurrency::ThreadPool* threadpool, const logging::Logger& logger, Stream* stream);

@@ -195,10 +199,6 @@ class OpKernelContext {
const OrtValue* GetImplicitInputMLValue(int index) const;
OrtValue* GetOutputMLValue(int index);

#ifdef ENABLE_ATEN
Status SetOutputMLValue(int index, const OrtValue& ort_value);
#endif

// Creates the OrtValue* based on the shape, if it does not exist
virtual OrtValue* OutputMLValue(int index, const TensorShape& shape);

9 changes: 9 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -4512,6 +4512,15 @@ struct OrtApi {
* \since Version 1.17.
*/
ORT_API2_STATUS(ReadOpAttr, _In_ const OrtOpAttr* op_attr, _In_ OrtOpAttrType type, _Inout_ void* data, _In_ size_t len, _Out_ size_t* out);

/** \brief Used for custom operators, set an output of a kernel
*
* \see ::OrtCustomOp
*
* \since Version 1.17.
*/
ORT_API2_STATUS(KernelContext_SetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index,
_In_ const OrtValue* ort_value);
};

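For illustration, a minimal sketch of how a custom-op kernel might call the new entry point through the C API. The function name MyKernelCompute and the surrounding setup are assumptions for the example, not part of this commit:

// Sketch: forward an existing OrtValue as output 0 of a custom-op kernel.
// Assumes `api` came from OrtGetApiBase()->GetApi(ORT_API_VERSION) and that
// `value` is an OrtValue produced elsewhere (e.g., allocated by the EP).
void MyKernelCompute(const OrtApi* api, OrtKernelContext* context, const OrtValue* value) {
  OrtStatus* status = api->KernelContext_SetOutput(context, /*index=*/0, value);
  if (status != nullptr) {
    api->ReleaseStatus(status);  // handle or propagate the failure in real code
  }
}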
1 change: 1 addition & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -2052,6 +2052,7 @@ struct KernelContext {
ConstValue GetInput(size_t index) const;
UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const;
UnownedValue GetOutput(size_t index, const std::vector<int64_t>& dims) const;
void SetOutput(size_t index, const OrtValue& ort_value);
void* GetGPUComputeStream() const;
Logger GetLogger() const;
OrtAllocator* GetAllocator(const OrtMemoryInfo& memory_info) const;
4 changes: 4 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -1634,6 +1634,10 @@ inline UnownedValue KernelContext::GetOutput(size_t index, const std::vector<int64_t>& dims) const {
return UnownedValue(out);
}

inline void KernelContext::SetOutput(size_t index, const OrtValue& ort_value) {
Ort::ThrowOnError(GetApi().KernelContext_SetOutput(ctx_, index, &ort_value));
}

inline void* KernelContext::GetGPUComputeStream() const {
void* out = nullptr;
Ort::ThrowOnError(GetApi().KernelContext_GetGPUComputeStream(ctx_, &out));
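The same call through the C++ wrapper added above, again only as a hedged sketch (the zero-copy scenario and the callback name are illustrative):

// Sketch: inside a custom op's compute callback.
void ComputeWithSetOutput(OrtKernelContext* raw_context, const OrtValue& precomputed) {
  Ort::KernelContext ctx(raw_context);
  ctx.SetOutput(0, precomputed);  // throws Ort::Exception if the C API reports an error
}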
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/execution_frame.cc
@@ -50,7 +50,7 @@ IExecutionFrame::IExecutionFrame(const OrtValueNameIdxMap& ort_value_idx_map,

IExecutionFrame::~IExecutionFrame() = default;

#ifdef ENABLE_ATEN
#if defined(ENABLE_ATEN) || defined(USE_TENSORRT)
Status IExecutionFrame::SetOutputMLValue(int index, const OrtValue& ort_value) {
int ort_value_idx = GetNodeIdxToMLValueIdx(index);
if (ort_value_idx == NodeIndexInfo::kInvalidEntry || static_cast<size_t>(ort_value_idx) >= all_values_size_) {
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/execution_frame.h
@@ -54,7 +54,7 @@ class IExecutionFrame {
const OrtValue* GetNodeInputOrOutputMLValue(int index) const;
OrtValue* GetMutableNodeInputOrOutputMLValue(int index);

#ifdef ENABLE_ATEN
#if defined(ENABLE_ATEN) || defined(USE_TENSORRT)
// Override the index-th output with ort_value
Status SetOutputMLValue(int index, const OrtValue& ort_value);
#endif
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/op_kernel.cc
@@ -186,7 +186,7 @@ AllocatorPtr OpKernelContext::GetAllocator(const OrtDevice& device) const {
return execution_frame_->GetAllocator(device);
}

#ifdef ENABLE_ATEN
#if defined(ENABLE_ATEN) || defined(USE_TENSORRT)
Status OpKernelContext::SetOutputMLValue(int index, const OrtValue& ort_value) {
if (index < 0 || index >= OutputCount()) {
return Status(common::ONNXRUNTIME, common::FAIL,
2,751 changes: 1,723 additions & 1,028 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

Large diffs are not rendered by default.

83 changes: 83 additions & 0 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -97,6 +97,61 @@ template <typename T>
using unique_pointer = std::unique_ptr<T, TensorrtInferDeleter>;
}; // namespace tensorrt_ptr

template <typename T>
inline T RoundUp(T m, T n) {
return ((m + n - 1) / n) * n;
}
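For example, RoundUp(1000, 256) = ((1000 + 255) / 256) * 256 = 1024; reallocateOutput below uses this to pad each allocation up to the alignment TensorRT requests.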

//
// Class to allocate memory for outputs with data-dependent shapes. Their sizes are unknown at
// build time, so pre-allocation is not possible.
//
class OutputAllocator : public nvinfer1::IOutputAllocator {
public:
OutputAllocator(OrtAllocator* alloc)
: allocator(alloc) {
}

void* reallocateOutput(
char const* tensorName, void* currentMemory, uint64_t size, uint64_t alignment) noexcept override {
// Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr
// even for empty tensors, so allocate a dummy byte.
size = std::max(size, static_cast<uint64_t>(1));
if (size > allocated_size) {
buffer = IAllocator::MakeUniquePtrFromOrtAllocator<void>(allocator, RoundUp(size, alignment));
allocated_size = size;
}
return buffer.get();
}

void* getBuffer() {
return buffer.get();
}

void notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept override {
output_shapes.clear();  // reset the shape recorded by a previous enqueue, if any
output_shapes.reserve(dims.nbDims);
for (int i = 0; i < dims.nbDims; i++) {
output_shapes.push_back(dims.d[i]);
}
}

std::vector<int64_t>& getOutputShape() {
return output_shapes;
}

uint64_t getSize() {
return allocated_size;
}

~OutputAllocator() override {}

private:
OrtAllocator* allocator = nullptr;
IAllocatorUniquePtr<void> buffer;
uint64_t allocated_size = 0;
std::vector<int64_t> output_shapes;
};
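A hedged sketch of how such an allocator is typically wired into TensorRT's data-dependent-shape flow; the tensor name, variables, and placement are assumptions for illustration (the real wiring lives in tensorrt_execution_provider.cc, whose diff is not rendered above):

// Sketch: let TensorRT size a DDS output at enqueue time.
// Assumes <NvInfer.h> and <cuda_runtime_api.h> are included.
void RunWithDdsOutput(nvinfer1::IExecutionContext& trt_context, OrtAllocator* alloc, cudaStream_t stream) {
  OutputAllocator dds_alloc(alloc);
  trt_context.setOutputAllocator("output_0", &dds_alloc);  // "output_0" is an assumed tensor name
  trt_context.enqueueV3(stream);  // TRT calls reallocateOutput() and notifyShape() during execution
  cudaStreamSynchronize(stream);  // shape and buffer are only valid after synchronization
  std::vector<int64_t>& shape = dds_alloc.getOutputShape();  // reported via notifyShape()
  void* device_data = dds_alloc.getBuffer();                 // sized by reallocateOutput()
  // ... bind device_data/shape into the kernel's output OrtValue here ...
  (void)shape; (void)device_data;
}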

using ShapeRangesMap = std::unordered_map<std::string, std::unordered_map<size_t, std::vector<std::vector<int64_t>>>>;

// Information to construct kernel function state.
@@ -145,6 +200,21 @@ struct TensorrtFuncState {
bool cuda_graph_enable = 0;
};

struct TensorrtShortFuncState {
AllocateFunc test_allocate_func = nullptr;
DestroyFunc test_release_func = nullptr;
AllocatorHandle allocator = nullptr;
std::unique_ptr<nvinfer1::ICudaEngine>* engine = nullptr;
std::unique_ptr<nvinfer1::IExecutionContext>* context = nullptr;
std::vector<std::unordered_map<std::string, size_t>> input_info;
std::vector<std::unordered_map<std::string, size_t>> output_info;
bool sync_stream_after_enqueue = false;
std::unordered_map<char const*, OutputAllocator*> dds_output_allocator_map;
bool context_memory_sharing_enable = false;
size_t* max_context_mem_size_ptr = nullptr;
OrtMutex* tensorrt_mu_ptr = nullptr;
};

// Holds important information for building valid ORT graph.
struct SubGraphContext {
std::unordered_set<std::string> output_args;
@@ -153,6 +223,7 @@ struct SubGraphContext {
};

using SubGraphContextMap = std::unordered_map<std::string, std::unique_ptr<SubGraphContext>>;
using DDSOutputAllocatorMap = std::unordered_map<char const*, OutputAllocator*>;

// Logical device representation.
class TensorrtExecutionProvider : public IExecutionProvider {
@@ -261,6 +332,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
std::unordered_map<std::string, std::vector<std::vector<int64_t>>> profile_opt_shapes_;
std::unordered_map<std::string, ShapeRangesMap> input_shape_ranges_; // The profile shape ranges that the engine is built with
std::unordered_map<std::string, std::vector<nvinfer1::IOptimizationProfile*>> profiles_;
std::unordered_map<std::string, DDSOutputAllocatorMap> dds_output_allocator_map_; // For DDS output tensor

// for external stream, we need to create its cudnn/cublass handle before cuda EP enable cuda graph capture
cudnnHandle_t external_cudnn_handle_ = nullptr;
@@ -452,6 +524,17 @@ class TensorrtExecutionProvider : public IExecutionProvider {
*/
bool IsLocalValue(const Graph& graph, const std::string& name) const;

Status CreateNodeComputeFromPrecompiledEngine(const GraphViewer& graph_body_viewer,
const Node& fused_node,
std::unordered_map<std::string, size_t>& input_map,
std::unordered_map<std::string, size_t>& output_map,
std::vector<NodeComputeInfo>& node_compute_funcs);
Status CreateNodeComputeFromOrtGraph(const GraphViewer& graph_body_viewer,
const Node& fused_node,
std::unordered_map<std::string, size_t>& input_map,
std::unordered_map<std::string, size_t>& output_map,
std::vector<NodeComputeInfo>& node_compute_funcs);

bool IsGraphCaptureAllowed() const;
void CaptureBegin();
void CaptureEnd();
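The two CreateNodeComputeFrom* declarations above suggest a split compile path. A sketch of the assumed dispatch inside Compile() (the full body sits in the 2,751-line tensorrt_execution_provider.cc diff that GitHub does not render):

// Sketch: pick between deserializing a precompiled engine and building from the ORT graph.
if (CheckPrecompiledEngine(graph_body_viewer)) {
  ORT_RETURN_IF_ERROR(CreateNodeComputeFromPrecompiledEngine(graph_body_viewer, fused_node,
                                                             input_map, output_map, node_compute_funcs));
} else {
  ORT_RETURN_IF_ERROR(CreateNodeComputeFromOrtGraph(graph_body_viewer, fused_node,
                                                    input_map, output_map, node_compute_funcs));
}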
@@ -5,6 +5,7 @@
#include <unordered_map>
#include <string>
#include <iostream>
#include <filesystem>
#include <experimental/filesystem>
#include "flatbuffers/idl.h"
#include "ort_trt_int8_cal_table.fbs.h"
@@ -13,6 +14,10 @@
#include "core/common/path_string.h"
#include "core/framework/murmurhash3.h"

static const std::string EP_CONTEXT_OP_TYPE = "EPContext";
static const std::string EP_CONTEXT_ATTR_EMBED_MODE = "embed_mode";
static const std::string EP_CONTEXT_ATTR_CACHE_CTX = "ep_cache_context";

namespace fs = std::experimental::filesystem;

namespace onnxruntime {
@@ -694,4 +699,30 @@ bool ParseProfileShapes(std::string profile_shapes_string, std::unordered_map<st

return true;
}


bool CheckPrecompiledEngine(const GraphViewer& graph) {
return graph.NumberOfNodes() == 1 && graph.GetNode(0)->OpType() == EP_CONTEXT_OP_TYPE;
}

/*
* Sanity check for the EPContext contrib op.
*/
bool IsValidEPContextNode(const GraphViewer& graph) {
assert(graph.NumberOfNodes() == 1);
assert(graph.GetNode(0)->OpType() == EP_CONTEXT_OP_TYPE);
auto node = graph.GetNode(0);
auto& attrs = node->GetAttributes();
if (attrs.count(EP_CONTEXT_ATTR_EMBED_MODE) > 0 && attrs.count(EP_CONTEXT_ATTR_CACHE_CTX) > 0) {
// ep_cache_context: payload of the execution provider context if embed_mode=1, or path to the context file if embed_mode=0
if (attrs.at(EP_CONTEXT_ATTR_EMBED_MODE).i() == 0 && !std::filesystem::exists(attrs.at(EP_CONTEXT_ATTR_CACHE_CTX).s())) {
LOGS_DEFAULT(ERROR) << "Can't find " << attrs.at(EP_CONTEXT_ATTR_CACHE_CTX).s() << " TensorRT engine";
return false;
}
}
return true;
}
} // namespace onnxruntime
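To make the embed_mode semantics concrete, a hypothetical helper (not part of this commit) that retrieves the serialized engine from a validated EPContext node:

// Hypothetical helper: return the serialized TensorRT engine bytes.
// Assumes <fstream> and <iterator> are included and IsValidEPContextNode(graph) passed.
std::string GetEngineBlob(const GraphViewer& graph) {
  const auto& attrs = graph.GetNode(0)->GetAttributes();
  const std::string& cache = attrs.at(EP_CONTEXT_ATTR_CACHE_CTX).s();
  if (attrs.at(EP_CONTEXT_ATTR_EMBED_MODE).i() == 1) {
    return cache;  // embed_mode=1: the attribute carries the engine payload inline
  }
  std::ifstream engine_file(cache, std::ios::binary);  // embed_mode=0: the attribute is a file path
  return std::string(std::istreambuf_iterator<char>(engine_file), std::istreambuf_iterator<char>());
}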
16 changes: 16 additions & 0 deletions onnxruntime/core/session/custom_ops.cc
@@ -310,6 +310,22 @@ ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count, _Out_ OrtValue** out) {
API_IMPL_END
};

ORT_API_STATUS_IMPL(OrtApis::KernelContext_SetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index, _In_ const OrtValue* ort_value) {
API_IMPL_BEGIN
#if defined(ENABLE_ATEN) || defined(USE_TENSORRT)
auto status = reinterpret_cast<onnxruntime::OpKernelContext*>(context)->SetOutputMLValue(gsl::narrow_cast<int>(index), *ort_value);
if (status.IsOK())
return nullptr;
return onnxruntime::ToOrtStatus(status);
#else
ORT_UNUSED_PARAMETER(context);
ORT_UNUSED_PARAMETER(index);
ORT_UNUSED_PARAMETER(ort_value);
return CreateStatus(ORT_FAIL, "SetOutput is only supported in builds with the ATen fallback or the TensorRT execution provider enabled.");
#endif
API_IMPL_END
};

ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_string, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ char* out, _Inout_ size_t* size) {
API_IMPL_BEGIN
std::string value;
1 change: 1 addition & 0 deletions onnxruntime/core/session/onnxruntime_c_api.cc
@@ -2721,6 +2721,7 @@ static constexpr OrtApi ort_api_1_to_17 = {
&OrtApis::ShapeInferContext_SetOutputTypeShape,
&OrtApis::SetSymbolicDimensions,
&OrtApis::ReadOpAttr,
&OrtApis::KernelContext_SetOutput,
};

// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
1 change: 1 addition & 0 deletions onnxruntime/core/session/ort_apis.h
@@ -184,6 +184,7 @@ ORT_API_STATUS_IMPL(KernelContext_GetInputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out);
ORT_API_STATUS_IMPL(KernelContext_GetOutputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out);
ORT_API_STATUS_IMPL(KernelContext_GetInput, _In_ const OrtKernelContext* context, _In_ size_t index, _Out_ const OrtValue** out);
ORT_API_STATUS_IMPL(KernelContext_GetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count, _Out_ OrtValue** out);
ORT_API_STATUS_IMPL(KernelContext_SetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index, _In_ const OrtValue* ort_value);

// OrtTypeInfo methods
ORT_API_STATUS_IMPL(GetDenotationFromTypeInfo, _In_ const OrtTypeInfo*, _Out_ const char** const denotation, _Out_ size_t* len);
4 changes: 3 additions & 1 deletion onnxruntime/test/providers/cpu/nn/dropout_op_test.cc
@@ -30,7 +30,9 @@ TEST(Dropout, WithOptionalOutputOpset10) {
test.AddInput<float>("X", dims, {1.0f, 2.0f, 3.0f, 5.0f});
test.AddOutput<float>("Y", dims, {1.0f, 2.0f, 3.0f, 5.0f});
test.AddOutput<bool>("mask", dims, {false, false, false, false});
test.Run();
// The fix for the Dropout ONNX node in the onnx-tensorrt parser is not included in TRT 8.6.1,
// but may be picked up in a later ORT release. Skip the TensorRT EP for this test for now.
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}

TEST(Dropout, WithOptionalOutputOpset7) {
