Skip to content

Commit

Permalink
fix format
Browse files Browse the repository at this point in the history
  • Loading branch information
chilo-ms committed Nov 19, 2023
1 parent 6605fd4 commit 699d538
Showing 1 changed file with 9 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,13 @@

dynamic_dim_count = 0


# TRT uses "-1" to represent dynamic dimension
# ORT uses symbolic name to represent dynamic dimension
# Here we only do the conversion when there is any dynamic dimension in the shape
def trt_shape_to_ort_shape(trt_data_shape):
def has_dynamic_dim(trt_data_shape):
if any(
dim == -1
for dim in trt_data_shape
):
if any(dim == -1 for dim in trt_data_shape):
return True
return False

Expand All @@ -36,6 +34,7 @@ def has_dynamic_dim(trt_data_shape):
ort_data_shape.append(dim)
return ort_data_shape


def trt_data_type_to_onnx_data_type(trt_data_type):
if trt_data_type == trt.DataType.FLOAT:
return TensorProto.FLOAT
Expand Down Expand Up @@ -140,15 +139,19 @@ def main():
for i in range(len(input_tensors)):
model_inputs.append(
helper.make_tensor_value_info(
input_tensors[i], trt_data_type_to_onnx_data_type(input_tensor_types[i]), trt_shape_to_ort_shape(input_tensor_shapes[i])
input_tensors[i],
trt_data_type_to_onnx_data_type(input_tensor_types[i]),
trt_shape_to_ort_shape(input_tensor_shapes[i]),
)
)

model_outputs = []
for i in range(len(output_tensors)):
model_outputs.append(
helper.make_tensor_value_info(
output_tensors[i], trt_data_type_to_onnx_data_type(output_tensor_types[i]), trt_shape_to_ort_shape(output_tensor_shapes[i])
output_tensors[i],
trt_data_type_to_onnx_data_type(output_tensor_types[i]),
trt_shape_to_ort_shape(output_tensor_shapes[i]),
)
)

Expand Down

0 comments on commit 699d538

Please sign in to comment.