Skip to content

Commit

Permalink
update script
Browse files Browse the repository at this point in the history
  • Loading branch information
chilo-ms committed Nov 19, 2023
1 parent e2c3f16 commit 6605fd4
Showing 1 changed file with 29 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,33 @@
import tensorrt as trt
from onnx import TensorProto, helper

dynamic_dim_count = 0

# TRT uses "-1" to represent dynamic dimension
# ORT uses symbolic name to represent dynamic dimension
# Here we only do the conversion when there is any dynamic dimension in the shape
def trt_shape_to_ort_shape(trt_data_shape):
def has_dynamic_dim(trt_data_shape):
if any(
dim == -1
for dim in trt_data_shape
):
return True
return False

if not has_dynamic_dim(trt_data_shape):
return trt_data_shape

ort_data_shape = []
if has_dynamic_dim(trt_data_shape):
for dim in trt_data_shape:
if dim == -1:
global dynamic_dim_count

Check warning

Code scanning / lintrunner

RUFF/PLW0603 Warning

Using the global statement to update dynamic\_dim\_count is discouraged.
See https://beta.ruff.rs/docs/rules/
ort_data_shape.append("free_dim_" + str(dynamic_dim_count))
dynamic_dim_count += 1
else:
ort_data_shape.append(dim)
return ort_data_shape

def trt_data_type_to_onnx_data_type(trt_data_type):
if trt_data_type == trt.DataType.FLOAT:
Expand Down Expand Up @@ -113,15 +140,15 @@ def main():
for i in range(len(input_tensors)):
model_inputs.append(
helper.make_tensor_value_info(
input_tensors[i], trt_data_type_to_onnx_data_type(input_tensor_types[i]), input_tensor_shapes[i]
input_tensors[i], trt_data_type_to_onnx_data_type(input_tensor_types[i]), trt_shape_to_ort_shape(input_tensor_shapes[i])
)
)

model_outputs = []
for i in range(len(output_tensors)):
model_outputs.append(
helper.make_tensor_value_info(
output_tensors[i], trt_data_type_to_onnx_data_type(output_tensor_types[i]), output_tensor_shapes[i]
output_tensors[i], trt_data_type_to_onnx_data_type(output_tensor_types[i]), trt_shape_to_ort_shape(output_tensor_shapes[i])
)
)

Expand Down

0 comments on commit 6605fd4

Please sign in to comment.