From c51093e9fc0c113529233d8e061db82c447852fb Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Wed, 20 Sep 2023 17:20:58 +0000 Subject: [PATCH] lint f for python --- .../python/tools/transformers/convert_generation.py | 12 ------------ .../transformers/models/whisper/whisper_chain.py | 4 +--- 2 files changed, 1 insertion(+), 15 deletions(-) diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index 25404377e9b17..c1c709d6d759b 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -1272,18 +1272,6 @@ def find_past_seq_len_usage(subg: GraphProto): return tensor_names_to_rename, nodes_to_remove -def dim_of(dim_proto): - return dim_proto.dim_param if hasattr(dim_proto, "dim_param") and dim_proto.dim_param else dim_proto.dim_value - - -def shape_of(vi): - return tuple([dim_of(dim_proto) for dim_proto in vi.type.tensor_type.shape.dim]) - - -def data_type_of(vi): - return vi.type.tensor_type.elem_type - - def update_decoder_subgraph_output_cross_attention(subg: GraphProto): input_self_past_0 = 1 # w/wo attention mask, w/wo hidden_state diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py index 82bbe802bed32..a1ed0c7ed5ca2 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py @@ -3,8 +3,6 @@ import sys import onnx -from benchmark_helper import Precision -from convert_generation import get_shared_initializers, update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha from onnx import TensorProto, helper from transformers import WhisperConfig @@ -236,7 +234,7 @@ def chain_model(args): if ( (pgi.name not in beam_graph_input_names) and (pgi.name not in beam_graph_output_names) - and (not pgi.name == "cross_qk") + and (pgi.name != "cross_qk") ): beam_graph.input.extend([pgi]) beam_graph.output.extend(post_qk_graph.output)