diff --git a/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py b/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py index 15614fd215b73..e95610bbfeec0 100644 --- a/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py +++ b/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py @@ -10,15 +10,13 @@ dynamic_dim_count = 0 + # TRT uses "-1" to represent dynamic dimension # ORT uses symbolic name to represent dynamic dimension # Here we only do the conversion when there is any dynamic dimension in the shape def trt_shape_to_ort_shape(trt_data_shape): def has_dynamic_dim(trt_data_shape): - if any( - dim == -1 - for dim in trt_data_shape - ): + if any(dim == -1 for dim in trt_data_shape): return True return False @@ -36,6 +34,7 @@ def has_dynamic_dim(trt_data_shape): ort_data_shape.append(dim) return ort_data_shape + def trt_data_type_to_onnx_data_type(trt_data_type): if trt_data_type == trt.DataType.FLOAT: return TensorProto.FLOAT @@ -140,7 +139,9 @@ def main(): for i in range(len(input_tensors)): model_inputs.append( helper.make_tensor_value_info( - input_tensors[i], trt_data_type_to_onnx_data_type(input_tensor_types[i]), trt_shape_to_ort_shape(input_tensor_shapes[i]) + input_tensors[i], + trt_data_type_to_onnx_data_type(input_tensor_types[i]), + trt_shape_to_ort_shape(input_tensor_shapes[i]), ) ) @@ -148,7 +149,9 @@ def main(): for i in range(len(output_tensors)): model_outputs.append( helper.make_tensor_value_info( - output_tensors[i], trt_data_type_to_onnx_data_type(output_tensor_types[i]), trt_shape_to_ort_shape(output_tensor_shapes[i]) + output_tensors[i], + trt_data_type_to_onnx_data_type(output_tensor_types[i]), + trt_shape_to_ort_shape(output_tensor_shapes[i]), ) )