From 8f9912fb5d1e37a24c9b64b468a09829d7816606 Mon Sep 17 00:00:00 2001
From: Peter McAughan
Date: Thu, 2 Nov 2023 23:33:30 +0000
Subject: [PATCH] Lintrunner fixes

---
 .../tools/transformers/convert_generation.py  |  9 +++++---
 .../models/llama/convert_to_onnx.py           | 22 +++++++++++--------
 .../transformers/models/llama/llama_inputs.py |  4 +++-
 3 files changed, 22 insertions(+), 13 deletions(-)

diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py
index 02e1b2e94c7f8..dd716d28910b5 100644
--- a/onnxruntime/python/tools/transformers/convert_generation.py
+++ b/onnxruntime/python/tools/transformers/convert_generation.py
@@ -1271,7 +1271,10 @@ def find_past_seq_len_usage(subg: GraphProto):
             nodes_to_remove.append(shape_node)
     return tensor_names_to_rename, nodes_to_remove

-def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0, world_size: int = 1, window_size: int = 0):
+
+def replace_mha_with_gqa(
+    model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0, world_size: int = 1, window_size: int = 0
+):
     past_seq_len = past_seq_len_input
     if past_seq_len not in model.get_graphs_input_names():
         # Add model input for past sequence length
@@ -1287,7 +1290,7 @@ def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads
                     num_heads_mha = att.i
         if window_size:
             gqa_node = onnx.helper.make_node(
-                "GroupQueryAttention", 
+                "GroupQueryAttention",
                 inputs=[
                     node.input[0],  # query
                     node.input[1],  # key
@@ -1331,7 +1334,7 @@ def update_decoder_subgraph_output_cross_attention(subg: GraphProto):
     input_self_past_0 = 1
     # w/wo attention mask, w/wo hidden_state
     graph_input_names = [gi.name for gi in subg.input]
-    while input_self_past_0 replace_mha_with_gqa 3 and not graph_input_names[input_self_past_0].startswith("past"):
+    while input_self_past_0 < 3 and not graph_input_names[input_self_past_0].startswith("past"):
         input_self_past_0 += 1

     output_self_present_0 = 1
diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
index 55ac35e016933..9a0f086cfb7eb 100644
--- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
@@ -14,7 +14,7 @@
 from onnx_model import OnnxModel
 from optimizer import optimize_model
 from packaging import version
-from transformers import LlamaConfig, LlamaForCausalLM, PretrainedConfig, AutoConfig
+from transformers import AutoConfig, LlamaConfig, LlamaForCausalLM, PretrainedConfig

 from onnxruntime import quantization as ort_quantization
 from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer
@@ -391,7 +391,7 @@ def run_torchscript_merged_export(


 # Optimize the model as FP32
-def optimize_export(config: LlamaConfig, input_path: str, output_path: str, remove_model = True):
+def optimize_export(config: LlamaConfig, input_path: str, output_path: str, remove_model=True):
     from fusion_options import FusionOptions

     optimization_options = FusionOptions("gpt2")
@@ -407,7 +407,7 @@ def optimize_export(config: LlamaConfig, input_path: str, output_path: str, remo
     )
     model_opt.save_model_to_file(output_path, use_external_data_format=True)
     logger.info(f"The ONNX model at {input_path} has been successfully optimized and saved at {output_path}!")
-    if (remove_model):
+    if remove_model:
         remove_existing_model(input_path)


@@ -438,7 +438,9 @@ def convert_to_float16(
     return new_paths


-def use_group_query_attention(config: LlamaConfig, fp16_model_opt: OnnxModel, world_size: int = 1, window_size: int = 0):
+def use_group_query_attention(
+    config: LlamaConfig, fp16_model_opt: OnnxModel, world_size: int = 1, window_size: int = 0
+):
     # Replace MultiHeadAttention with GroupQueryAttention and remove attention mask nodes
     fp16_model_opt = replace_mha_with_gqa(
         fp16_model_opt, "past_sequence_length", config.num_key_value_heads, world_size, window_size
@@ -447,6 +449,7 @@ def use_group_query_attention(config: LlamaConfig, fp16_model_opt: OnnxModel, wo
     fp16_model_opt.update_graph(allow_remove_graph_inputs=True)
     return fp16_model_opt

+
 def smooth_quant(
     args: argparse.Namespace,
     decoder_model_fp32_path: str,
@@ -539,17 +542,18 @@ def remove_existing_files(output_path: str):
             os.remove(filepath)
             logger.warning(f"Removed {filepath}")

+
 def optimize_optimum(config: PretrainedConfig, args):
-    tmp_file = os.path.join(args.output, args.model_name+".tmp.onnx")
-    output_file = os.path.join(args.output, args.model_name+".onnx")
+    tmp_file = os.path.join(args.output, args.model_name + ".tmp.onnx")
+    output_file = os.path.join(args.output, args.model_name + ".onnx")
     optimize_export(config, args.input, tmp_file, remove_model=False)
     logger.info(f"Model successfully optimized to {tmp_file}")
     opt_model = OnnxModel(onnx.load_model(tmp_file, load_external_data=True))
     if args.precision == Precision.FLOAT16:
         opt_model.convert_float_to_float16(keep_io_types=False)
-        window_size = 0 if not hasattr(config, "sliding_window") else config.sliding_window 
+        window_size = 0 if not hasattr(config, "sliding_window") else config.sliding_window
         opt_model = use_group_query_attention(config, opt_model, args.world_size, window_size)
-        logger.info(f"Model successfully fused and quantized to FP16!")
+        logger.info("Model successfully fused and quantized to FP16!")
     opt_model.save_model_to_file(output_file, use_external_data_format=True)
     logger.info(f"Output model successfully saved to {output_file}")
     logger.info(f"Removing {tmp_file}")
@@ -745,7 +749,7 @@ def main():
         # Second predicate is for comparing nightly (ex: 2.2.0.dev20230920 vs 2.2.0) since first predicate is false
         # in that scenario. It can be removed when torch v2.2.0 is released in stable.
         logger.warning(f"Detected PyTorch version {torch.__version__}. Please upgrade and use v2.2.0 or newer.")
-        #return
+        # return

     args = get_args()
     setup_logger(args.verbose)
diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
index 84b57c5049abc..6423a58ec2bc6 100644
--- a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
+++ b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
@@ -222,7 +222,9 @@ def get_msft_sample_inputs(

 # Create past_key_values
 # Each is of shape (batch_size, num_heads, past_sequence_length, head_size)
-def get_past_kv_inputs(config: PretrainedConfig, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int = 1):
+def get_past_kv_inputs(
+    config: PretrainedConfig, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int = 1
+):
     num_heads, head_size = config.num_key_value_heads // world_size, config.hidden_size // config.num_attention_heads
     torch_dtype = torch.float16 if use_fp16 else torch.float32
     past_kv = [
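
For reference (not part of the patch itself): a minimal, hypothetical sketch of how the reformatted replace_mha_with_gqa signature, which now also takes window_size for sliding-window models, is driven, mirroring the optimize_optimum / use_group_query_attention flow in this diff. The model id and file paths are placeholders, and it assumes onnxruntime/python/tools/transformers is on PYTHONPATH so that onnx_model and convert_generation import the same way they do in the patched files.

    import onnx
    from convert_generation import replace_mha_with_gqa  # from onnxruntime/python/tools/transformers
    from onnx_model import OnnxModel
    from transformers import AutoConfig

    # Placeholder model id and paths; any config exposing num_key_value_heads works.
    config = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1")
    opt_model = OnnxModel(onnx.load_model("model.tmp.onnx", load_external_data=True))
    opt_model.convert_float_to_float16(keep_io_types=False)

    # As in optimize_optimum: fall back to 0 when the config defines no sliding window.
    window_size = 0 if not hasattr(config, "sliding_window") else config.sliding_window

    # Swap MultiHeadAttention for GroupQueryAttention, then drop now-unused graph inputs.
    opt_model = replace_mha_with_gqa(
        opt_model, "past_sequence_length", config.num_key_value_heads, world_size=1, window_size=window_size
    )
    opt_model.update_graph(allow_remove_graph_inputs=True)
    opt_model.save_model_to_file("model.onnx", use_external_data_format=True)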