
Commit

update
chilo-ms committed Nov 4, 2023
1 parent 34a86d7 commit 9631f73
Showing 2 changed files with 3 additions and 43 deletions.
44 changes: 2 additions & 42 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -365,46 +365,6 @@ std::unique_lock<OrtMutex> TensorrtExecutionProvider::GetApiLock() const {
return std::unique_lock<OrtMutex>(singleton);
}

-Status GetShapeOfShapeTensor(Ort::ConstValue& input_tensor,
-                             std::vector<int32_t>& shape_values,
-                             nvinfer1::ICudaEngine* trt_engine,
-                             int binding_index,
-                             cudaStream_t stream) {
-  auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
-  const auto tensor_shapes = tensor_info.GetShape();
-  const auto tensor_type = tensor_info.GetElementType();
-  nvinfer1::Dims dims = trt_engine->getBindingDimensions(static_cast<int>(binding_index));
-  int nb_dims = dims.nbDims;
-  int shape_size = nb_dims == 0 ? 1 : static_cast<int>(tensor_shapes[0]);  // The shape of the "shape tensor" is either zero dimension (scalar) or 1-dimension
-  shape_values.resize(shape_size, 1);
-
-  switch (tensor_type) {
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
-      auto input = std::make_unique<int32_t[]>(shape_size);
-      CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input.get(), input_tensor.GetTensorData<int32_t>(), shape_size * sizeof(int32_t), cudaMemcpyDeviceToHost, stream));
-      CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
-      for (int j = 0; j < shape_size; ++j) {
-        shape_values[j] = input[j];
-      }
-      break;
-    }
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
-      auto input = std::make_unique<int64_t[]>(shape_size);
-      CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input.get(), input_tensor.GetTensorData<int64_t>(), shape_size * sizeof(int64_t), cudaMemcpyDeviceToHost, stream));
-      CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
-      for (int j = 0; j < shape_size; ++j) {
-        shape_values[j] = static_cast<int32_t>(input[j]);
-      }
-      break;
-    }
-    default: {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
-                             "TensorRT shape tensor data type: " + std::to_string(tensor_type) + " not supported.");
-    }
-  }
-  return Status::OK();
-}
-
/*
* Get the shape of "shape tensor" input
*/
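The overload deleted above reads a shape tensor's values back to the host with an asynchronous copy followed by a stream synchronization. A minimal self-contained sketch of that pattern, assuming a hypothetical helper name and using only the CUDA runtime API (no ONNX Runtime types):

// Sketch of the device-to-host pattern from the deleted overload.
// read_shape_values_i32 is a hypothetical name for illustration.
#include <cuda_runtime.h>

#include <cstdint>
#include <vector>

bool read_shape_values_i32(const int32_t* device_data, int shape_size,
                           cudaStream_t stream, std::vector<int32_t>& out) {
  out.resize(shape_size);
  // Enqueue the copy on the given stream, then block until it completes
  // so the host-side values are valid before the caller inspects them.
  if (cudaMemcpyAsync(out.data(), device_data, shape_size * sizeof(int32_t),
                      cudaMemcpyDeviceToHost, stream) != cudaSuccess) {
    return false;
  }
  return cudaStreamSynchronize(stream) == cudaSuccess;
}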
@@ -2239,7 +2199,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
// If the model consists of only a single "EPContext" contrib op, it means TRT EP can fetch the precompiled engine info from the node and
// load the engine directly without having to go through the processes of graph proto reconstruction, calling TRT parser and engine compilation.
// So, simply return the ComputeCapability here.
-  if (IsFusedGraphHasCtxNode(graph)) {
+  if (GraphHasCtxNode(graph)) {
if (IsValidCtxNode(graph)) {
SubGraph_t supported_node_vector = {{0}, false};
std::unique_ptr<IndexedSubGraph> sub_graph = GetSubGraph(supported_node_vector, graph, TRTGenerateId(graph), 0);
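The comment in this hunk carries the rationale: a graph that is a single "EPContext" node already references a precompiled engine, so GetCapability can claim the whole graph at once and skip graph-proto reconstruction, TRT parsing, and engine compilation. A compilable mock of that short-circuit; GraphViewer, IndexedSubGraph, and ComputeCapability here are stand-ins for the ONNX Runtime internals, not the real types:

#include <memory>
#include <vector>

// Stand-ins for ONNX Runtime internals, for illustration only.
struct GraphViewer {};
struct IndexedSubGraph {};  // would record which node indices are claimed
struct ComputeCapability {
  explicit ComputeCapability(std::unique_ptr<IndexedSubGraph> sg)
      : sub_graph(std::move(sg)) {}
  std::unique_ptr<IndexedSubGraph> sub_graph;
};

bool GraphHasCtxNode(const GraphViewer&) { return true; }  // real check: second file below

std::vector<std::unique_ptr<ComputeCapability>> GetCapabilitySketch(const GraphViewer& graph) {
  std::vector<std::unique_ptr<ComputeCapability>> result;
  if (GraphHasCtxNode(graph)) {
    // Claim the entire graph (the single "EPContext" node, index 0) and
    // return immediately; no parsing or partitioning is needed.
    result.push_back(std::make_unique<ComputeCapability>(std::make_unique<IndexedSubGraph>()));
    return result;
  }
  // ...normal TensorRT partitioning path...
  return result;
}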
@@ -3715,7 +3675,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
}

Status status;
-  if (IsFusedGraphHasCtxNode(graph_body_viewer)) {
+  if (GraphHasCtxNode(graph_body_viewer)) {
status = CreateNodeComputeFromPrecompiledEngine(graph_body_viewer, fused_node, input_map, output_map, node_compute_funcs);
} else {
status = CreateNodeComputeFromOrtGraph(graph_body_viewer, fused_node, input_map, output_map, node_compute_funcs);
@@ -706,7 +706,7 @@ bool ParseProfileShapes(std::string profile_shapes_string, std::unordered_map<st
*
* Note: Please see more details about "EPContext" contrib op in contrib_defs.cc
*/
-bool IsFusedGraphHasCtxNode(const GraphViewer& graph) {
+bool GraphHasCtxNode(const GraphViewer& graph) {
if (graph.NumberOfNodes() == 1) {
for (int i = 0; i < graph.MaxNodeIndex(); ++i) {
auto node = graph.GetNode(i);
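The hunk cuts off mid-loop; presumably the loop checks each node's op type against the "EPContext" contrib op named in the function's doc comment. A self-contained mock under that assumption (the real GraphViewer and Node types live inside ONNX Runtime):

#include <string>
#include <vector>

// Mock graph types standing in for onnxruntime::GraphViewer / Node.
struct Node { std::string op_type; };
struct Graph {
  std::vector<Node> nodes;
  int NumberOfNodes() const { return static_cast<int>(nodes.size()); }
};

// Presumed shape of the check: exactly one node, and it is "EPContext".
bool GraphHasCtxNodeMock(const Graph& graph) {
  if (graph.NumberOfNodes() != 1) return false;
  for (const Node& node : graph.nodes) {
    if (node.op_type == "EPContext") return true;
  }
  return false;
}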
