From 65423a2875f624afb3e7c3c430d19f9ae534531a Mon Sep 17 00:00:00 2001 From: Emmanuel Date: Mon, 1 Jul 2024 18:09:18 +0000 Subject: [PATCH 1/9] dynamo export success --- .../models/llama/convert_to_onnx.py | 63 +++++++++++-------- 1 file changed, 36 insertions(+), 27 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index 8a33544654e05..615e25f65d39f 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -15,6 +15,7 @@ import onnx import torch +import tempfile from benchmark_helper import Precision, prepare_environment, setup_logger from convert_generation import replace_mha_with_gqa from dist_settings import barrier, get_rank, get_size, init_dist @@ -149,35 +150,45 @@ def run_dynamo_export( config.capture_scalar_outputs = True # Dummy values for export - batch_size, sequence_length = 2, 8 - device = torch.device("cpu") + batch_size, sequence_length, past_sequence_length = 2, 8, 0 + device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu") # Export decoder_model.onnx - input_ids, attn_mask, pos_ids = get_sample_inputs(l_config, device, batch_size, sequence_length) - temp_dir = args.output # tempfile.TemporaryDirectory() - temp_path = os.path.join(temp_dir, "temp.onnx") # os.path.join(temp_dir.name, "temp.onnx") - torch.onnx.dynamo_export( - llama, input_ids, attn_mask, pos_ids, export_options=torch.onnx.ExportOptions(dynamic_shapes=True) - ).save(temp_path) + # input_ids, attn_mask, pos_ids = get_sample_inputs(l_config, device, batch_size, sequence_length) + # temp_dir = tempfile.TemporaryDirectory() + # temp_path = os.path.join(temp_dir.name, "temp.onnx") + # torch.onnx.dynamo_export( + # llama, input_ids, attn_mask, pos_ids, export_options=torch.onnx.ExportOptions(dynamic_shapes=True) + # ).save(temp_path) - # Check decoder_model.onnx and save all external data to one file - onnx.checker.check_model(temp_path) - onnx.shape_inference.infer_shapes_path(temp_path) + # # Check decoder_model.onnx and save all external data to one file + # onnx.checker.check_model(temp_path) + # onnx.shape_inference.infer_shapes_path(temp_path) + + # output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx") + # onnx_model = onnx.load_model(temp_path, load_external_data=True) + # save_onnx_model(onnx_model, output_path, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx.data") + # del onnx_model + # temp_dir.cleanup() + + + temp_name = args.model_name.lower().replace("-", "").replace("_", "") + max_sequence_length = 16384 if "codellama" in temp_name else 4096 if "llama2" in temp_name else 2048 - output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx") - onnx_model = onnx.load_model(temp_path, load_external_data=True) - save_onnx_model(onnx_model, output_path, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx.data") - del onnx_model - os.system( - f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}" - ) # temp_dir.cleanup() # Export decoder_with_past_model.onnx - input_ids, attn_mask, pos_ids, past_kv = get_sample_with_past_kv_inputs( - l_config, device, batch_size, sequence_length, world_size=world_size + input_ids, attn_mask, pos_ids, past_kv = get_merged_sample_with_past_kv_inputs( + l_config, + device, + batch_size, + 
sequence_length, + past_sequence_length, + max_seq_len=max_sequence_length, + use_fp16=False, + world_size=world_size, ) - temp_dir = args.output # tempfile.TemporaryDirectory() - temp_path = os.path.join(temp_dir, "temp.onnx") # os.path.join(temp_dir.name, "temp.onnx") + temp_dir = tempfile.TemporaryDirectory() + temp_path = os.path.join(temp_dir.name, "temp.onnx") torch.onnx.dynamo_export( llama, input_ids, attn_mask, pos_ids, past_kv, export_options=torch.onnx.ExportOptions(dynamic_shapes=True) ).save(temp_path) @@ -190,9 +201,7 @@ def run_dynamo_export( onnx_model = onnx.load_model(temp_path, load_external_data=True) save_onnx_model(onnx_model, output_path, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx.data") del onnx_model - os.system( - f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}" - ) # temp_dir.cleanup() + temp_dir.cleanup() logger.info(f"The {args.model_name} ONNX model has been successfully created with the Dynamo exporter!") @@ -869,7 +878,7 @@ def main(): # Export to ONNX if missing_separate_exports or missing_merged_export: - if args.use_dynamo_export and missing_separate_exports: + if args.use_dynamo_export and missing_merged_export: logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.") logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/") logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/") @@ -1056,4 +1065,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file From 16c94c8a2fc0598ac37e6864a1cc9860ab6fc3fa Mon Sep 17 00:00:00 2001 From: Emmanuel Date: Tue, 23 Jul 2024 23:12:44 +0000 Subject: [PATCH 2/9] changes to dynamo --- .../models/llama/convert_to_onnx.py | 208 +++++++++--------- 1 file changed, 104 insertions(+), 104 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index 615e25f65d39f..aebaa0b916876 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -878,7 +878,7 @@ def main(): # Export to ONNX if missing_separate_exports or missing_merged_export: - if args.use_dynamo_export and missing_merged_export: + if args.use_dynamo_export: logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.") logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/") logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/") @@ -911,109 +911,109 @@ def main(): decoder_merged_model_fp32_opt_path, ] - # Run the optimizer script - logger.info("Optimizing models...") - for orig_path, opt_path in zip(old_paths, new_paths): - if os.path.exists(orig_path): - optimize_export(args, l_config, input_path=orig_path, output_path=opt_path, world_size=world_size) - - # Re-assign default FP32 model paths as their optimized versions - decoder_model_fp32_path = decoder_model_fp32_opt_path - decoder_with_past_model_fp32_path = decoder_with_past_model_fp32_opt_path - decoder_merged_model_fp32_path = decoder_merged_model_fp32_opt_path - old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path] - - logger.info( - f"The {args.model_name} ONNX model has been successfully optimized with the ORT transformer optimizer script!" 
- ) - - # Change precision of exported models from FP32 - if args.precision == Precision.FLOAT16: - new_paths = convert_to_float16(args, old_paths, rank) - - elif args.precision == Precision.INT8: - decoder_model_int8_path = os.path.join( - args.output, f"rank_{rank}_{args.model_name}_decoder_model_int8.onnx" - ) - decoder_with_past_model_int8_path = os.path.join( - args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int8.onnx" - ) - decoder_merged_model_int8_path = os.path.join( - args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int8.onnx" - ) - new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path] - - if args.quantization_method == "smooth_quant": - if not args.no_merged: - logger.error("SmoothQuant must be used on separately exported models") - else: - logger.info( - f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8" - ) - smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1]) - - elif args.quantization_method == "quantize_dynamic": - logger.warning( - "The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. Precision loss may be high with `quantize_dynamic`." - ) - - logger.info("Quantizing to int8...") - for fp32_path, int8_path in zip(old_paths, new_paths): - if os.path.exists(fp32_path): - ort_quantization.quantize_dynamic( - fp32_path, - int8_path, - op_types_to_quantize=( - ["MatMul", "Gemm", "Gather"] - if args.quantize_embedding_layer - else ["MatMul", "Gemm"] - ), - per_channel=args.quantize_per_channel, - reduce_range=args.quantize_reduce_range, - use_external_data_format=True, - extra_options={"MatMulConstBOnly": True}, - ) - logger.info( - f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!" 
- ) - remove_existing_model(decoder_model_fp32_path) - - logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!") - - else: - raise Exception(f"Could not recognize {args.quantization_method} as a quantization method") - - elif args.precision == Precision.INT4: - if args.execution_provider != "cpu": - old_paths = convert_to_float16(args, old_paths, rank) - - decoder_model_int4_path = os.path.join( - args.output, f"rank_{rank}_{args.model_name}_decoder_model_int4.onnx" - ) - decoder_with_past_model_int4_path = os.path.join( - args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int4.onnx" - ) - decoder_merged_model_int4_path = os.path.join( - args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int4.onnx" - ) - new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path] - - for fp_path, int4_path in zip(old_paths, new_paths): - if os.path.exists(fp_path): - model = onnx.load_model(fp_path, load_external_data=True) - quant = MatMul4BitsQuantizer( - model=model, - block_size=args.block_size, - is_symmetric=True, - accuracy_level=args.int4_accuracy_level, - nodes_to_exclude=[], - ) - quant.process() - quant.model.save_model_to_file(int4_path, use_external_data_format=True) - del model - del quant - logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!") - remove_existing_model(fp_path) + # # Run the optimizer script, runs the torch as well + # logger.info("Optimizing models...") + # for orig_path, opt_path in zip(old_paths, new_paths): + # if os.path.exists(orig_path): + # optimize_export(args, l_config, input_path=orig_path, output_path=opt_path, world_size=world_size) + + # # Re-assign default FP32 model paths as their optimized versions + # decoder_model_fp32_path = decoder_model_fp32_opt_path + # decoder_with_past_model_fp32_path = decoder_with_past_model_fp32_opt_path + # decoder_merged_model_fp32_path = decoder_merged_model_fp32_opt_path + # old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path] + + # logger.info( + # f"The {args.model_name} ONNX model has been successfully optimized with the ORT transformer optimizer script!" + # ) + + # # Change precision of exported models from FP32 + # if args.precision == Precision.FLOAT16: + # new_paths = convert_to_float16(args, old_paths, rank) + + # elif args.precision == Precision.INT8: + # decoder_model_int8_path = os.path.join( + # args.output, f"rank_{rank}_{args.model_name}_decoder_model_int8.onnx" + # ) + # decoder_with_past_model_int8_path = os.path.join( + # args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int8.onnx" + # ) + # decoder_merged_model_int8_path = os.path.join( + # args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int8.onnx" + # ) + # new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path] + + # if args.quantization_method == "smooth_quant": + # if not args.no_merged: + # logger.error("SmoothQuant must be used on separately exported models") + # else: + # logger.info( + # f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8" + # ) + # smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1]) + + # elif args.quantization_method == "quantize_dynamic": + # logger.warning( + # "The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. 
Precision loss may be high with `quantize_dynamic`." + # ) + + # logger.info("Quantizing to int8...") + # for fp32_path, int8_path in zip(old_paths, new_paths): + # if os.path.exists(fp32_path): + # ort_quantization.quantize_dynamic( + # fp32_path, + # int8_path, + # op_types_to_quantize=( + # ["MatMul", "Gemm", "Gather"] + # if args.quantize_embedding_layer + # else ["MatMul", "Gemm"] + # ), + # per_channel=args.quantize_per_channel, + # reduce_range=args.quantize_reduce_range, + # use_external_data_format=True, + # extra_options={"MatMulConstBOnly": True}, + # ) + # logger.info( + # f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!" + # ) + # remove_existing_model(decoder_model_fp32_path) + + # logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!") + + # else: + # raise Exception(f"Could not recognize {args.quantization_method} as a quantization method") + + # elif args.precision == Precision.INT4: + # if args.execution_provider != "cpu": + # old_paths = convert_to_float16(args, old_paths, rank) + + # decoder_model_int4_path = os.path.join( + # args.output, f"rank_{rank}_{args.model_name}_decoder_model_int4.onnx" + # ) + # decoder_with_past_model_int4_path = os.path.join( + # args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int4.onnx" + # ) + # decoder_merged_model_int4_path = os.path.join( + # args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int4.onnx" + # ) + # new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path] + + # for fp_path, int4_path in zip(old_paths, new_paths): + # if os.path.exists(fp_path): + # model = onnx.load_model(fp_path, load_external_data=True) + # quant = MatMul4BitsQuantizer( + # model=model, + # block_size=args.block_size, + # is_symmetric=True, + # accuracy_level=args.int4_accuracy_level, + # nodes_to_exclude=[], + # ) + # quant.process() + # quant.model.save_model_to_file(int4_path, use_external_data_format=True) + # del model + # del quant + # logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!") + # remove_existing_model(fp_path) barrier() logger.info("Verifying parity on all ONNX models created") From 6af5245f41d0727e93f865dbfdba948fc3e0887b Mon Sep 17 00:00:00 2001 From: Emmanuel Date: Mon, 12 Aug 2024 22:20:44 +0000 Subject: [PATCH 3/9] Added changes to enable dynamo exporter --- .../models/llama/convert_to_onnx.py | 4 +- .../transformers/models/llama/llama_inputs.py | 53 +++++++++++++++++++ .../transformers/models/llama/llama_parity.py | 16 +++++- .../transformers/models/llama/llama_torch.py | 1 + 4 files changed, 70 insertions(+), 4 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index aebaa0b916876..e550c1ae7a73c 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -153,7 +153,7 @@ def run_dynamo_export( batch_size, sequence_length, past_sequence_length = 2, 8, 0 device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu") - # Export decoder_model.onnx + # Export decoder_model.onnx, commented out for now to use dynamo_export # input_ids, attn_mask, pos_ids = get_sample_inputs(l_config, device, batch_size, sequence_length) # temp_dir = tempfile.TemporaryDirectory() # temp_path = os.path.join(temp_dir.name, 
"temp.onnx") @@ -911,7 +911,7 @@ def main(): decoder_merged_model_fp32_opt_path, ] - # # Run the optimizer script, runs the torch as well + # # Run the optimizer script, runs the torch as well. Keeping this block commented makes sure only Dynamo export is used. # logger.info("Optimizing models...") # for orig_path, opt_path in zip(old_paths, new_paths): # if os.path.exists(orig_path): diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py index 39f0588436d2e..1cb3f3587bc29 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py @@ -28,6 +28,7 @@ def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool): # input_ids: (batch_size, sequence_length) # attention_mask: (batch_size, sequence_length) # position_ids: (batch_size, sequence_length) + def get_sample_inputs( config: AutoConfig, device: torch.device, @@ -171,6 +172,58 @@ def get_merged_sample_with_past_kv_inputs( return inputs +def get_dynamo_inputs( + config: AutoConfig, + device: torch.device, + batch_size: int, + seq_len: int, + past_seq_len: int, + max_seq_len: int, + use_fp16: bool = False, + use_buffer_share: bool = False, + engine: str = "pt", + return_dict: bool = False, + world_size: int = 1, +): + input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64) + attention_mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64) + # position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation + position_ids = get_position_ids(attention_mask, use_past_kv=(past_seq_len != 0)) + past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size) + + # Convert inputs to NumPy (for ORT) or send to device (for PyTorch) + input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device) + attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device) + position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device) + past_kv = ( + flatten_past_kv_inputs(past_kv) + if engine == "ort" + else list(map(lambda kv: (kv[0].to(device), kv[1].to(device)), past_kv)) + ) + + if not return_dict: + # For export + assert isinstance(past_kv, list) + return (input_ids, attention_mask, position_ids, past_kv) + + inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + } + if engine == "ort": + assert isinstance(past_kv, dict) + inputs.update(past_kv) + + if use_buffer_share: + inputs = enable_past_present_share_buffer(inputs, past_seq_len, max_seq_len) + + else: + assert isinstance(past_kv, list) + inputs["past_key_values"] = past_kv + + return inputs + # Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx def get_msft_sample_inputs( diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py index eab55154b50b1..5a3b431ca9a4c 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py @@ -21,6 +21,7 @@ get_sample_inputs, get_sample_with_past_kv_inputs, verify_ort_inputs, + get_dynamo_inputs ) from llama_torch import setup_torch_model from transformers import AutoConfig @@ -41,8 +42,10 @@ def get_inputs(args: 
argparse.Namespace, config: AutoConfig): world_size = get_size() batch_size = 2 past_sequence_length, sequence_length, max_sequence_length = get_sequence_lengths(args, config) - - if args.merged: + + if args.dynamo: + inputs = get_dynamo_inputs(config, args.device, batch_size, sequence_length, return_dict=True) + elif args.merged: inputs = get_merged_sample_with_past_kv_inputs( config, args.device, @@ -166,6 +169,7 @@ def verify_parity( def get_args(argv: list[str]): parser = argparse.ArgumentParser() + parser.add_argument( "-m", "--model_name", @@ -236,6 +240,14 @@ def get_args(argv: list[str]): choices=["int4", "int8", "fp16", "fp32"], help="Precision of model", ) + + parser.add_argument( + "--dynamo", + action="store_true", + help="Use Dynamo model inputs for parity check", + ) + parser.set_defaults(dynamo=False) + parser.add_argument( "--cache_dir", diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_torch.py b/onnxruntime/python/tools/transformers/models/llama/llama_torch.py index 643b21ce61343..6fcf0a517258d 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_torch.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_torch.py @@ -29,6 +29,7 @@ def setup_torch_model(args, location, auth, torch_dtype=torch.float32, device=No ) l_config.use_cache = True l_config._attn_implementation = "eager" # "eager" uses LlamaAttention for attention layer + # l_config.num_hidden_layers = 1 llama = AutoModelForCausalLM.from_pretrained( location, use_auth_token=auth, From 0f0ef372bfcb843a5e84c8fac2cda180154e1337 Mon Sep 17 00:00:00 2001 From: Emmanuel Date: Wed, 14 Aug 2024 22:05:36 +0000 Subject: [PATCH 4/9] resolved changes --- .../models/llama/convert_to_onnx.py | 224 ++++++++---------- .../transformers/models/llama/llama_inputs.py | 52 ---- .../transformers/models/llama/llama_parity.py | 18 +- 3 files changed, 107 insertions(+), 187 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index e550c1ae7a73c..2d5992220175d 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -153,23 +153,6 @@ def run_dynamo_export( batch_size, sequence_length, past_sequence_length = 2, 8, 0 device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu") - # Export decoder_model.onnx, commented out for now to use dynamo_export - # input_ids, attn_mask, pos_ids = get_sample_inputs(l_config, device, batch_size, sequence_length) - # temp_dir = tempfile.TemporaryDirectory() - # temp_path = os.path.join(temp_dir.name, "temp.onnx") - # torch.onnx.dynamo_export( - # llama, input_ids, attn_mask, pos_ids, export_options=torch.onnx.ExportOptions(dynamic_shapes=True) - # ).save(temp_path) - - # # Check decoder_model.onnx and save all external data to one file - # onnx.checker.check_model(temp_path) - # onnx.shape_inference.infer_shapes_path(temp_path) - - # output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx") - # onnx_model = onnx.load_model(temp_path, load_external_data=True) - # save_onnx_model(onnx_model, output_path, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx.data") - # del onnx_model - # temp_dir.cleanup() temp_name = args.model_name.lower().replace("-", "").replace("_", "") @@ -911,109 +894,110 @@ def main(): decoder_merged_model_fp32_opt_path, ] - # # Run the optimizer script, 
runs the torch as well. Keeping this block commented makes sure only Dynamo export is used. - # logger.info("Optimizing models...") - # for orig_path, opt_path in zip(old_paths, new_paths): - # if os.path.exists(orig_path): - # optimize_export(args, l_config, input_path=orig_path, output_path=opt_path, world_size=world_size) - - # # Re-assign default FP32 model paths as their optimized versions - # decoder_model_fp32_path = decoder_model_fp32_opt_path - # decoder_with_past_model_fp32_path = decoder_with_past_model_fp32_opt_path - # decoder_merged_model_fp32_path = decoder_merged_model_fp32_opt_path - # old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path] - - # logger.info( - # f"The {args.model_name} ONNX model has been successfully optimized with the ORT transformer optimizer script!" - # ) - - # # Change precision of exported models from FP32 - # if args.precision == Precision.FLOAT16: - # new_paths = convert_to_float16(args, old_paths, rank) - - # elif args.precision == Precision.INT8: - # decoder_model_int8_path = os.path.join( - # args.output, f"rank_{rank}_{args.model_name}_decoder_model_int8.onnx" - # ) - # decoder_with_past_model_int8_path = os.path.join( - # args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int8.onnx" - # ) - # decoder_merged_model_int8_path = os.path.join( - # args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int8.onnx" - # ) - # new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path] - - # if args.quantization_method == "smooth_quant": - # if not args.no_merged: - # logger.error("SmoothQuant must be used on separately exported models") - # else: - # logger.info( - # f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8" - # ) - # smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1]) - - # elif args.quantization_method == "quantize_dynamic": - # logger.warning( - # "The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. Precision loss may be high with `quantize_dynamic`." - # ) - - # logger.info("Quantizing to int8...") - # for fp32_path, int8_path in zip(old_paths, new_paths): - # if os.path.exists(fp32_path): - # ort_quantization.quantize_dynamic( - # fp32_path, - # int8_path, - # op_types_to_quantize=( - # ["MatMul", "Gemm", "Gather"] - # if args.quantize_embedding_layer - # else ["MatMul", "Gemm"] - # ), - # per_channel=args.quantize_per_channel, - # reduce_range=args.quantize_reduce_range, - # use_external_data_format=True, - # extra_options={"MatMulConstBOnly": True}, - # ) - # logger.info( - # f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!" 
- # ) - # remove_existing_model(decoder_model_fp32_path) - - # logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!") - - # else: - # raise Exception(f"Could not recognize {args.quantization_method} as a quantization method") - - # elif args.precision == Precision.INT4: - # if args.execution_provider != "cpu": - # old_paths = convert_to_float16(args, old_paths, rank) - - # decoder_model_int4_path = os.path.join( - # args.output, f"rank_{rank}_{args.model_name}_decoder_model_int4.onnx" - # ) - # decoder_with_past_model_int4_path = os.path.join( - # args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int4.onnx" - # ) - # decoder_merged_model_int4_path = os.path.join( - # args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int4.onnx" - # ) - # new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path] - - # for fp_path, int4_path in zip(old_paths, new_paths): - # if os.path.exists(fp_path): - # model = onnx.load_model(fp_path, load_external_data=True) - # quant = MatMul4BitsQuantizer( - # model=model, - # block_size=args.block_size, - # is_symmetric=True, - # accuracy_level=args.int4_accuracy_level, - # nodes_to_exclude=[], - # ) - # quant.process() - # quant.model.save_model_to_file(int4_path, use_external_data_format=True) - # del model - # del quant - # logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!") - # remove_existing_model(fp_path) + # Run the optimizer script, runs the torch as well. + if not args.use_dynamo_export: + logger.info("Optimizing models...") + for orig_path, opt_path in zip(old_paths, new_paths): + if os.path.exists(orig_path): + optimize_export(args, l_config, input_path=orig_path, output_path=opt_path, world_size=world_size) + + # Re-assign default FP32 model paths as their optimized versions + decoder_model_fp32_path = decoder_model_fp32_opt_path + decoder_with_past_model_fp32_path = decoder_with_past_model_fp32_opt_path + decoder_merged_model_fp32_path = decoder_merged_model_fp32_opt_path + old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path] + + logger.info( + f"The {args.model_name} ONNX model has been successfully optimized with the ORT transformer optimizer script!" + ) + + # Change precision of exported models from FP32 + if args.precision == Precision.FLOAT16: + new_paths = convert_to_float16(args, old_paths, rank) + + elif args.precision == Precision.INT8: + decoder_model_int8_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_model_int8.onnx" + ) + decoder_with_past_model_int8_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int8.onnx" + ) + decoder_merged_model_int8_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int8.onnx" + ) + new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path] + + if args.quantization_method == "smooth_quant": + if not args.no_merged: + logger.error("SmoothQuant must be used on separately exported models") + else: + logger.info( + f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8" + ) + smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1]) + + elif args.quantization_method == "quantize_dynamic": + logger.warning( + "The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. 
Precision loss may be high with `quantize_dynamic`." + ) + + logger.info("Quantizing to int8...") + for fp32_path, int8_path in zip(old_paths, new_paths): + if os.path.exists(fp32_path): + ort_quantization.quantize_dynamic( + fp32_path, + int8_path, + op_types_to_quantize=( + ["MatMul", "Gemm", "Gather"] + if args.quantize_embedding_layer + else ["MatMul", "Gemm"] + ), + per_channel=args.quantize_per_channel, + reduce_range=args.quantize_reduce_range, + use_external_data_format=True, + extra_options={"MatMulConstBOnly": True}, + ) + logger.info( + f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!" + ) + remove_existing_model(decoder_model_fp32_path) + + logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!") + + else: + raise Exception(f"Could not recognize {args.quantization_method} as a quantization method") + + elif args.precision == Precision.INT4: + if args.execution_provider != "cpu": + old_paths = convert_to_float16(args, old_paths, rank) + + decoder_model_int4_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_model_int4.onnx" + ) + decoder_with_past_model_int4_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int4.onnx" + ) + decoder_merged_model_int4_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int4.onnx" + ) + new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path] + + for fp_path, int4_path in zip(old_paths, new_paths): + if os.path.exists(fp_path): + model = onnx.load_model(fp_path, load_external_data=True) + quant = MatMul4BitsQuantizer( + model=model, + block_size=args.block_size, + is_symmetric=True, + accuracy_level=args.int4_accuracy_level, + nodes_to_exclude=[], + ) + quant.process() + quant.model.save_model_to_file(int4_path, use_external_data_format=True) + del model + del quant + logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!") + remove_existing_model(fp_path) barrier() logger.info("Verifying parity on all ONNX models created") diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py index 1cb3f3587bc29..ffe346a1b46fa 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py @@ -172,58 +172,6 @@ def get_merged_sample_with_past_kv_inputs( return inputs -def get_dynamo_inputs( - config: AutoConfig, - device: torch.device, - batch_size: int, - seq_len: int, - past_seq_len: int, - max_seq_len: int, - use_fp16: bool = False, - use_buffer_share: bool = False, - engine: str = "pt", - return_dict: bool = False, - world_size: int = 1, -): - input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64) - attention_mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64) - # position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation - position_ids = get_position_ids(attention_mask, use_past_kv=(past_seq_len != 0)) - past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size) - - # Convert inputs to NumPy (for ORT) or send to device (for PyTorch) - input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device) - attention_mask = attention_mask.numpy() if engine == "ort" 
else attention_mask.to(device) - position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device) - past_kv = ( - flatten_past_kv_inputs(past_kv) - if engine == "ort" - else list(map(lambda kv: (kv[0].to(device), kv[1].to(device)), past_kv)) - ) - - if not return_dict: - # For export - assert isinstance(past_kv, list) - return (input_ids, attention_mask, position_ids, past_kv) - - inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "position_ids": position_ids, - } - if engine == "ort": - assert isinstance(past_kv, dict) - inputs.update(past_kv) - - if use_buffer_share: - inputs = enable_past_present_share_buffer(inputs, past_seq_len, max_seq_len) - - else: - assert isinstance(past_kv, list) - inputs["past_key_values"] = past_kv - - return inputs - # Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx def get_msft_sample_inputs( diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py index 5a3b431ca9a4c..170ed15a07368 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py @@ -21,7 +21,6 @@ get_sample_inputs, get_sample_with_past_kv_inputs, verify_ort_inputs, - get_dynamo_inputs ) from llama_torch import setup_torch_model from transformers import AutoConfig @@ -42,10 +41,8 @@ def get_inputs(args: argparse.Namespace, config: AutoConfig): world_size = get_size() batch_size = 2 past_sequence_length, sequence_length, max_sequence_length = get_sequence_lengths(args, config) - - if args.dynamo: - inputs = get_dynamo_inputs(config, args.device, batch_size, sequence_length, return_dict=True) - elif args.merged: + + if args.merged: inputs = get_merged_sample_with_past_kv_inputs( config, args.device, @@ -169,7 +166,6 @@ def verify_parity( def get_args(argv: list[str]): parser = argparse.ArgumentParser() - parser.add_argument( "-m", "--model_name", @@ -240,14 +236,6 @@ def get_args(argv: list[str]): choices=["int4", "int8", "fp16", "fp32"], help="Precision of model", ) - - parser.add_argument( - "--dynamo", - action="store_true", - help="Use Dynamo model inputs for parity check", - ) - parser.set_defaults(dynamo=False) - parser.add_argument( "--cache_dir", @@ -318,4 +306,4 @@ def main(argv: list[str] = []): # noqa: B006 seed = 2 np.random.seed(seed) torch.manual_seed(seed) - main() + main() \ No newline at end of file From d10f07921fc90ab0d0c0f19cd9bdf222659a6bdf Mon Sep 17 00:00:00 2001 From: Emmanuel Date: Wed, 14 Aug 2024 22:16:52 +0000 Subject: [PATCH 5/9] resolved changes --- .../python/tools/transformers/models/llama/llama_torch.py | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_torch.py b/onnxruntime/python/tools/transformers/models/llama/llama_torch.py index 6fcf0a517258d..643b21ce61343 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_torch.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_torch.py @@ -29,7 +29,6 @@ def setup_torch_model(args, location, auth, torch_dtype=torch.float32, device=No ) l_config.use_cache = True l_config._attn_implementation = "eager" # "eager" uses LlamaAttention for attention layer - # l_config.num_hidden_layers = 1 llama = AutoModelForCausalLM.from_pretrained( location, use_auth_token=auth, From b703e0b37a7db113c889a149db757927b0733c3b Mon Sep 17 00:00:00 2001 From: Emmanuel Date: Wed, 14 Aug 2024 22:28:15 +0000 
Subject: [PATCH 6/9] resolved changes --- .../python/tools/transformers/models/llama/llama_parity.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py index 170ed15a07368..76be4031d34ce 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py @@ -306,4 +306,5 @@ def main(argv: list[str] = []): # noqa: B006 seed = 2 np.random.seed(seed) torch.manual_seed(seed) - main() \ No newline at end of file + main() + \ No newline at end of file From 174888c221d111bd61cdfd83d1ca75a42e3b41d7 Mon Sep 17 00:00:00 2001 From: Emmanuel Date: Thu, 15 Aug 2024 00:52:43 +0000 Subject: [PATCH 7/9] adding extra changes --- .../models/llama/convert_to_onnx.py | 186 +++++++++--------- 1 file changed, 95 insertions(+), 91 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index 2d5992220175d..7118cfa57422b 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -896,109 +896,113 @@ def main(): # Run the optimizer script, runs the torch as well. if not args.use_dynamo_export: - logger.info("Optimizing models...") - for orig_path, opt_path in zip(old_paths, new_paths): - if os.path.exists(orig_path): - optimize_export(args, l_config, input_path=orig_path, output_path=opt_path, world_size=world_size) - - # Re-assign default FP32 model paths as their optimized versions - decoder_model_fp32_path = decoder_model_fp32_opt_path - decoder_with_past_model_fp32_path = decoder_with_past_model_fp32_opt_path - decoder_merged_model_fp32_path = decoder_merged_model_fp32_opt_path - old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path] - - logger.info( - f"The {args.model_name} ONNX model has been successfully optimized with the ORT transformer optimizer script!" + continue + # Rest of optimizer code for TorchScript exported model + logger.info("Optimizing models...") + for orig_path, opt_path in zip(old_paths, new_paths): + if os.path.exists(orig_path): + optimize_export(args, l_config, input_path=orig_path, output_path=opt_path, world_size=world_size) + + # Re-assign default FP32 model paths as their optimized versions + decoder_model_fp32_path = decoder_model_fp32_opt_path + decoder_with_past_model_fp32_path = decoder_with_past_model_fp32_opt_path + decoder_merged_model_fp32_path = decoder_merged_model_fp32_opt_path + old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path] + + logger.info( + f"The {args.model_name} ONNX model has been successfully optimized with the ORT transformer optimizer script!" 
+ ) + + # Change precision of exported models from FP32 + if args.precision == Precision.FLOAT16: + new_paths = convert_to_float16(args, old_paths, rank) + + elif args.precision == Precision.INT8: + decoder_model_int8_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_model_int8.onnx" + ) + decoder_with_past_model_int8_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int8.onnx" + ) + decoder_merged_model_int8_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int8.onnx" ) + new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path] - # Change precision of exported models from FP32 - if args.precision == Precision.FLOAT16: - new_paths = convert_to_float16(args, old_paths, rank) + if args.quantization_method == "smooth_quant": + if not args.no_merged: + logger.error("SmoothQuant must be used on separately exported models") + else: + logger.info( + f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8" + ) + smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1]) - elif args.precision == Precision.INT8: - decoder_model_int8_path = os.path.join( - args.output, f"rank_{rank}_{args.model_name}_decoder_model_int8.onnx" - ) - decoder_with_past_model_int8_path = os.path.join( - args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int8.onnx" - ) - decoder_merged_model_int8_path = os.path.join( - args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int8.onnx" + elif args.quantization_method == "quantize_dynamic": + logger.warning( + "The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. Precision loss may be high with `quantize_dynamic`." ) - new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path] - if args.quantization_method == "smooth_quant": - if not args.no_merged: - logger.error("SmoothQuant must be used on separately exported models") - else: + logger.info("Quantizing to int8...") + for fp32_path, int8_path in zip(old_paths, new_paths): + if os.path.exists(fp32_path): + ort_quantization.quantize_dynamic( + fp32_path, + int8_path, + op_types_to_quantize=( + ["MatMul", "Gemm", "Gather"] + if args.quantize_embedding_layer + else ["MatMul", "Gemm"] + ), + per_channel=args.quantize_per_channel, + reduce_range=args.quantize_reduce_range, + use_external_data_format=True, + extra_options={"MatMulConstBOnly": True}, + ) logger.info( - f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8" + f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!" ) - smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1]) + remove_existing_model(decoder_model_fp32_path) - elif args.quantization_method == "quantize_dynamic": - logger.warning( - "The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. Precision loss may be high with `quantize_dynamic`." 
- ) - - logger.info("Quantizing to int8...") - for fp32_path, int8_path in zip(old_paths, new_paths): - if os.path.exists(fp32_path): - ort_quantization.quantize_dynamic( - fp32_path, - int8_path, - op_types_to_quantize=( - ["MatMul", "Gemm", "Gather"] - if args.quantize_embedding_layer - else ["MatMul", "Gemm"] - ), - per_channel=args.quantize_per_channel, - reduce_range=args.quantize_reduce_range, - use_external_data_format=True, - extra_options={"MatMulConstBOnly": True}, - ) - logger.info( - f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!" - ) - remove_existing_model(decoder_model_fp32_path) - - logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!") + logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!") - else: - raise Exception(f"Could not recognize {args.quantization_method} as a quantization method") + else: + raise Exception(f"Could not recognize {args.quantization_method} as a quantization method") - elif args.precision == Precision.INT4: - if args.execution_provider != "cpu": - old_paths = convert_to_float16(args, old_paths, rank) + elif args.precision == Precision.INT4: + if args.execution_provider != "cpu": + old_paths = convert_to_float16(args, old_paths, rank) - decoder_model_int4_path = os.path.join( - args.output, f"rank_{rank}_{args.model_name}_decoder_model_int4.onnx" - ) - decoder_with_past_model_int4_path = os.path.join( - args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int4.onnx" - ) - decoder_merged_model_int4_path = os.path.join( - args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int4.onnx" - ) - new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path] - - for fp_path, int4_path in zip(old_paths, new_paths): - if os.path.exists(fp_path): - model = onnx.load_model(fp_path, load_external_data=True) - quant = MatMul4BitsQuantizer( - model=model, - block_size=args.block_size, - is_symmetric=True, - accuracy_level=args.int4_accuracy_level, - nodes_to_exclude=[], - ) - quant.process() - quant.model.save_model_to_file(int4_path, use_external_data_format=True) - del model - del quant - logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!") - remove_existing_model(fp_path) + decoder_model_int4_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_model_int4.onnx" + ) + decoder_with_past_model_int4_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int4.onnx" + ) + decoder_merged_model_int4_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int4.onnx" + ) + new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path] + + for fp_path, int4_path in zip(old_paths, new_paths): + if os.path.exists(fp_path): + model = onnx.load_model(fp_path, load_external_data=True) + quant = MatMul4BitsQuantizer( + model=model, + block_size=args.block_size, + is_symmetric=True, + accuracy_level=args.int4_accuracy_level, + nodes_to_exclude=[], + ) + quant.process() + quant.model.save_model_to_file(int4_path, use_external_data_format=True) + del model + del quant + logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!") + remove_existing_model(fp_path) barrier() + if args.use_dynamo_export: + return logger.info("Verifying parity on all ONNX models created") From 
f1472d84ac42438e554bb6facfd16958701873c0 Mon Sep 17 00:00:00 2001 From: Emmanuel Date: Fri, 16 Aug 2024 00:33:52 +0000 Subject: [PATCH 8/9] final pr changes --- .../models/llama/convert_to_onnx.py | 41 +++---------------- .../transformers/models/llama/llama_inputs.py | 1 - .../transformers/models/llama/llama_parity.py | 1 - 3 files changed, 5 insertions(+), 38 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index 7118cfa57422b..46b5ea745d284 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -11,11 +11,11 @@ import shutil import subprocess import sys +import tempfile from itertools import chain import onnx import torch -import tempfile from benchmark_helper import Precision, prepare_environment, setup_logger from convert_generation import replace_mha_with_gqa from dist_settings import barrier, get_rank, get_size, init_dist @@ -114,34 +114,6 @@ def save_onnx_model(onnx_model: onnx.ModelProto, output_path: str, data_path: st ) -# Notes: -# 1) Dynamo export will not work automatically until this issue is resolved: https://github.com/microsoft/onnxscript/issues/493 -# -# 2) Dynamo export will run manually if you set the ONNX file path to the same path that you use to save the model after export. -# In other words, the value of `temp_path` should be set as the ONNX file path. You can open the issue in your browser to find -# the location in ONNX Script where you have to make this change. -# -# Once the issue is resolved, we hope to modify the code below as follows for each export. -# -# Before: -# temp_dir = args.output -# temp_path = os.path.join(temp_dir, "temp.onnx") -# ... -# ... -# ... -# del onnx_model -# os.system(f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}") -# -# -# After: -# temp_dir = tempfile.TemporaryDirectory() -# temp_path = os.path.join(temp_dir.name, "temp.onnx") -# ... -# ... -# ... -# del onnx_model -# temp_dir.cleanup() -# def run_dynamo_export( args: argparse.Namespace, l_config: AutoConfig, llama: AutoModelForCausalLM, rank: int = 0, world_size: int = 1 ): @@ -153,12 +125,9 @@ def run_dynamo_export( batch_size, sequence_length, past_sequence_length = 2, 8, 0 device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu") - - temp_name = args.model_name.lower().replace("-", "").replace("_", "") max_sequence_length = 16384 if "codellama" in temp_name else 4096 if "llama2" in temp_name else 2048 - # Export decoder_with_past_model.onnx input_ids, attn_mask, pos_ids, past_kv = get_merged_sample_with_past_kv_inputs( l_config, @@ -894,10 +863,10 @@ def main(): decoder_merged_model_fp32_opt_path, ] - # Run the optimizer script, runs the torch as well. - if not args.use_dynamo_export: + if args.use_dynamo_export: continue - # Rest of optimizer code for TorchScript exported model + + # Run the optimizer script. 
logger.info("Optimizing models...") for orig_path, opt_path in zip(old_paths, new_paths): if os.path.exists(orig_path): @@ -1053,4 +1022,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py index ffe346a1b46fa..39f0588436d2e 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py @@ -28,7 +28,6 @@ def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool): # input_ids: (batch_size, sequence_length) # attention_mask: (batch_size, sequence_length) # position_ids: (batch_size, sequence_length) - def get_sample_inputs( config: AutoConfig, device: torch.device, diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py index 76be4031d34ce..eab55154b50b1 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py @@ -307,4 +307,3 @@ def main(argv: list[str] = []): # noqa: B006 np.random.seed(seed) torch.manual_seed(seed) main() - \ No newline at end of file From 644047f3992ea9fec556a02aa5c7e2d69c77a43d Mon Sep 17 00:00:00 2001 From: Emmanuel Date: Fri, 16 Aug 2024 00:37:49 +0000 Subject: [PATCH 9/9] final pr --- .../python/tools/transformers/models/llama/convert_to_onnx.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index 46b5ea745d284..c94e5a173a2df 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -970,6 +970,7 @@ def main(): logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!") remove_existing_model(fp_path) barrier() + if args.use_dynamo_export: return