diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
index 615e25f65d39f..aebaa0b916876 100644
--- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
@@ -878,7 +878,7 @@ def main():
 
     # Export to ONNX
     if missing_separate_exports or missing_merged_export:
-        if args.use_dynamo_export and missing_merged_export:
+        if args.use_dynamo_export:
             logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.")
             logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/")
             logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/")
@@ -911,109 +911,109 @@ def main():
         decoder_merged_model_fp32_opt_path,
     ]
 
-    # Run the optimizer script
-    logger.info("Optimizing models...")
-    for orig_path, opt_path in zip(old_paths, new_paths):
-        if os.path.exists(orig_path):
-            optimize_export(args, l_config, input_path=orig_path, output_path=opt_path, world_size=world_size)
-
-    # Re-assign default FP32 model paths as their optimized versions
-    decoder_model_fp32_path = decoder_model_fp32_opt_path
-    decoder_with_past_model_fp32_path = decoder_with_past_model_fp32_opt_path
-    decoder_merged_model_fp32_path = decoder_merged_model_fp32_opt_path
-    old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path]
-
-    logger.info(
-        f"The {args.model_name} ONNX model has been successfully optimized with the ORT transformer optimizer script!"
-    )
-
-    # Change precision of exported models from FP32
-    if args.precision == Precision.FLOAT16:
-        new_paths = convert_to_float16(args, old_paths, rank)
-
-    elif args.precision == Precision.INT8:
-        decoder_model_int8_path = os.path.join(
-            args.output, f"rank_{rank}_{args.model_name}_decoder_model_int8.onnx"
-        )
-        decoder_with_past_model_int8_path = os.path.join(
-            args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int8.onnx"
-        )
-        decoder_merged_model_int8_path = os.path.join(
-            args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int8.onnx"
-        )
-        new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path]
-
-        if args.quantization_method == "smooth_quant":
-            if not args.no_merged:
-                logger.error("SmoothQuant must be used on separately exported models")
-            else:
-                logger.info(
-                    f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8"
-                )
-                smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1])
-
-        elif args.quantization_method == "quantize_dynamic":
-            logger.warning(
-                "The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. Precision loss may be high with `quantize_dynamic`."
-            )
-
-            logger.info("Quantizing to int8...")
-            for fp32_path, int8_path in zip(old_paths, new_paths):
-                if os.path.exists(fp32_path):
-                    ort_quantization.quantize_dynamic(
-                        fp32_path,
-                        int8_path,
-                        op_types_to_quantize=(
-                            ["MatMul", "Gemm", "Gather"]
-                            if args.quantize_embedding_layer
-                            else ["MatMul", "Gemm"]
-                        ),
-                        per_channel=args.quantize_per_channel,
-                        reduce_range=args.quantize_reduce_range,
-                        use_external_data_format=True,
-                        extra_options={"MatMulConstBOnly": True},
-                    )
-                    logger.info(
-                        f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!"
-                    )
-                    remove_existing_model(decoder_model_fp32_path)
-
-            logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!")
-
-        else:
-            raise Exception(f"Could not recognize {args.quantization_method} as a quantization method")
-
-    elif args.precision == Precision.INT4:
-        if args.execution_provider != "cpu":
-            old_paths = convert_to_float16(args, old_paths, rank)
-
-        decoder_model_int4_path = os.path.join(
-            args.output, f"rank_{rank}_{args.model_name}_decoder_model_int4.onnx"
-        )
-        decoder_with_past_model_int4_path = os.path.join(
-            args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int4.onnx"
-        )
-        decoder_merged_model_int4_path = os.path.join(
-            args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int4.onnx"
-        )
-        new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path]
-
-        for fp_path, int4_path in zip(old_paths, new_paths):
-            if os.path.exists(fp_path):
-                model = onnx.load_model(fp_path, load_external_data=True)
-                quant = MatMul4BitsQuantizer(
-                    model=model,
-                    block_size=args.block_size,
-                    is_symmetric=True,
-                    accuracy_level=args.int4_accuracy_level,
-                    nodes_to_exclude=[],
-                )
-                quant.process()
-                quant.model.save_model_to_file(int4_path, use_external_data_format=True)
-                del model
-                del quant
-                logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!")
-                remove_existing_model(fp_path)
+    # # Run the optimizer script, runs the torch as well
+    # logger.info("Optimizing models...")
+    # for orig_path, opt_path in zip(old_paths, new_paths):
+    #     if os.path.exists(orig_path):
+    #         optimize_export(args, l_config, input_path=orig_path, output_path=opt_path, world_size=world_size)
+
+    # # Re-assign default FP32 model paths as their optimized versions
+    # decoder_model_fp32_path = decoder_model_fp32_opt_path
+    # decoder_with_past_model_fp32_path = decoder_with_past_model_fp32_opt_path
+    # decoder_merged_model_fp32_path = decoder_merged_model_fp32_opt_path
+    # old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path]
+
+    # logger.info(
+    #     f"The {args.model_name} ONNX model has been successfully optimized with the ORT transformer optimizer script!"
+    # )
+
+    # # Change precision of exported models from FP32
+    # if args.precision == Precision.FLOAT16:
+    #     new_paths = convert_to_float16(args, old_paths, rank)
+
+    # elif args.precision == Precision.INT8:
+    #     decoder_model_int8_path = os.path.join(
+    #         args.output, f"rank_{rank}_{args.model_name}_decoder_model_int8.onnx"
+    #     )
+    #     decoder_with_past_model_int8_path = os.path.join(
+    #         args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int8.onnx"
+    #     )
+    #     decoder_merged_model_int8_path = os.path.join(
+    #         args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int8.onnx"
+    #     )
+    #     new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path]
+
+    #     if args.quantization_method == "smooth_quant":
+    #         if not args.no_merged:
+    #             logger.error("SmoothQuant must be used on separately exported models")
+    #         else:
+    #             logger.info(
+    #                 f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8"
+    #             )
+    #             smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1])
+
+    #     elif args.quantization_method == "quantize_dynamic":
+    #         logger.warning(
+    #             "The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. Precision loss may be high with `quantize_dynamic`."
+    #         )
+
+    #         logger.info("Quantizing to int8...")
+    #         for fp32_path, int8_path in zip(old_paths, new_paths):
+    #             if os.path.exists(fp32_path):
+    #                 ort_quantization.quantize_dynamic(
+    #                     fp32_path,
+    #                     int8_path,
+    #                     op_types_to_quantize=(
+    #                         ["MatMul", "Gemm", "Gather"]
+    #                         if args.quantize_embedding_layer
+    #                         else ["MatMul", "Gemm"]
+    #                     ),
+    #                     per_channel=args.quantize_per_channel,
+    #                     reduce_range=args.quantize_reduce_range,
+    #                     use_external_data_format=True,
+    #                     extra_options={"MatMulConstBOnly": True},
+    #                 )
+    #                 logger.info(
+    #                     f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!"
+    #                 )
+    #                 remove_existing_model(decoder_model_fp32_path)
+
+    #         logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!")
+
+    #     else:
+    #         raise Exception(f"Could not recognize {args.quantization_method} as a quantization method")
+
+    # elif args.precision == Precision.INT4:
+    #     if args.execution_provider != "cpu":
+    #         old_paths = convert_to_float16(args, old_paths, rank)
+
+    #     decoder_model_int4_path = os.path.join(
+    #         args.output, f"rank_{rank}_{args.model_name}_decoder_model_int4.onnx"
+    #     )
+    #     decoder_with_past_model_int4_path = os.path.join(
+    #         args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int4.onnx"
+    #     )
+    #     decoder_merged_model_int4_path = os.path.join(
+    #         args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int4.onnx"
+    #     )
+    #     new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path]
+
+    #     for fp_path, int4_path in zip(old_paths, new_paths):
+    #         if os.path.exists(fp_path):
+    #             model = onnx.load_model(fp_path, load_external_data=True)
+    #             quant = MatMul4BitsQuantizer(
+    #                 model=model,
+    #                 block_size=args.block_size,
+    #                 is_symmetric=True,
+    #                 accuracy_level=args.int4_accuracy_level,
+    #                 nodes_to_exclude=[],
+    #             )
+    #             quant.process()
+    #             quant.model.save_model_to_file(int4_path, use_external_data_format=True)
+    #             del model
+    #             del quant
+    #             logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!")
+    #             remove_existing_model(fp_path)
 
     barrier()
     logger.info("Verifying parity on all ONNX models created")