diff --git a/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py b/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py index a2292679f3ea2..15614fd215b73 100644 --- a/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py +++ b/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py @@ -8,6 +8,33 @@ import tensorrt as trt from onnx import TensorProto, helper +dynamic_dim_count = 0 + +# TRT uses "-1" to represent dynamic dimension +# ORT uses symbolic name to represent dynamic dimension +# Here we only do the conversion when there is any dynamic dimension in the shape +def trt_shape_to_ort_shape(trt_data_shape): + def has_dynamic_dim(trt_data_shape): + if any( + dim == -1 + for dim in trt_data_shape + ): + return True + return False + + if not has_dynamic_dim(trt_data_shape): + return trt_data_shape + + ort_data_shape = [] + if has_dynamic_dim(trt_data_shape): + for dim in trt_data_shape: + if dim == -1: + global dynamic_dim_count + ort_data_shape.append("free_dim_" + str(dynamic_dim_count)) + dynamic_dim_count += 1 + else: + ort_data_shape.append(dim) + return ort_data_shape def trt_data_type_to_onnx_data_type(trt_data_type): if trt_data_type == trt.DataType.FLOAT: @@ -113,7 +140,7 @@ def main(): for i in range(len(input_tensors)): model_inputs.append( helper.make_tensor_value_info( - input_tensors[i], trt_data_type_to_onnx_data_type(input_tensor_types[i]), input_tensor_shapes[i] + input_tensors[i], trt_data_type_to_onnx_data_type(input_tensor_types[i]), trt_shape_to_ort_shape(input_tensor_shapes[i]) ) ) @@ -121,7 +148,7 @@ def main(): for i in range(len(output_tensors)): model_outputs.append( helper.make_tensor_value_info( - output_tensors[i], trt_data_type_to_onnx_data_type(output_tensor_types[i]), output_tensor_shapes[i] + output_tensors[i], trt_data_type_to_onnx_data_type(output_tensor_types[i]), trt_shape_to_ort_shape(output_tensor_shapes[i]) ) )