-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
29 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,33 @@ | |
import tensorrt as trt | ||
from onnx import TensorProto, helper | ||
|
||
dynamic_dim_count = 0 | ||
|
||
# TRT uses "-1" to represent dynamic dimension | ||
# ORT uses symbolic name to represent dynamic dimension | ||
# Here we only do the conversion when there is any dynamic dimension in the shape | ||
def trt_shape_to_ort_shape(trt_data_shape): | ||
def has_dynamic_dim(trt_data_shape): | ||
if any( | ||
dim == -1 | ||
for dim in trt_data_shape | ||
): | ||
return True | ||
return False | ||
|
||
if not has_dynamic_dim(trt_data_shape): | ||
return trt_data_shape | ||
|
||
ort_data_shape = [] | ||
if has_dynamic_dim(trt_data_shape): | ||
for dim in trt_data_shape: | ||
if dim == -1: | ||
global dynamic_dim_count | ||
Check warning Code scanning / lintrunner RUFF/PLW0603 Warning
Using the global statement to update dynamic\_dim\_count is discouraged.
See https://beta.ruff.rs/docs/rules/ |
||
ort_data_shape.append("free_dim_" + str(dynamic_dim_count)) | ||
dynamic_dim_count += 1 | ||
else: | ||
ort_data_shape.append(dim) | ||
return ort_data_shape | ||
|
||
def trt_data_type_to_onnx_data_type(trt_data_type): | ||
if trt_data_type == trt.DataType.FLOAT: | ||
|
@@ -113,15 +140,15 @@ def main(): | |
for i in range(len(input_tensors)): | ||
model_inputs.append( | ||
helper.make_tensor_value_info( | ||
input_tensors[i], trt_data_type_to_onnx_data_type(input_tensor_types[i]), input_tensor_shapes[i] | ||
input_tensors[i], trt_data_type_to_onnx_data_type(input_tensor_types[i]), trt_shape_to_ort_shape(input_tensor_shapes[i]) | ||
) | ||
) | ||
|
||
model_outputs = [] | ||
for i in range(len(output_tensors)): | ||
model_outputs.append( | ||
helper.make_tensor_value_info( | ||
output_tensors[i], trt_data_type_to_onnx_data_type(output_tensor_types[i]), output_tensor_shapes[i] | ||
output_tensors[i], trt_data_type_to_onnx_data_type(output_tensor_types[i]), trt_shape_to_ort_shape(output_tensor_shapes[i]) | ||
) | ||
) | ||
|
||
|