changes to dynamo
kobby-kobbs committed Jul 23, 2024
1 parent 65423a2 commit 16c94c8
Showing 1 changed file with 104 additions and 104 deletions.
208 changes: 104 additions & 104 deletions onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
@@ -878,7 +878,7 @@ def main():

# Export to ONNX
if missing_separate_exports or missing_merged_export:
if args.use_dynamo_export and missing_merged_export:
if args.use_dynamo_export:
logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.")
logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/")
logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/")
@@ -911,109 +911,109 @@ def main():
decoder_merged_model_fp32_opt_path,
]

# Run the optimizer script
logger.info("Optimizing models...")
for orig_path, opt_path in zip(old_paths, new_paths):
if os.path.exists(orig_path):
optimize_export(args, l_config, input_path=orig_path, output_path=opt_path, world_size=world_size)

# Re-assign default FP32 model paths as their optimized versions
decoder_model_fp32_path = decoder_model_fp32_opt_path
decoder_with_past_model_fp32_path = decoder_with_past_model_fp32_opt_path
decoder_merged_model_fp32_path = decoder_merged_model_fp32_opt_path
old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path]

logger.info(
f"The {args.model_name} ONNX model has been successfully optimized with the ORT transformer optimizer script!"
)

# Change precision of exported models from FP32
if args.precision == Precision.FLOAT16:
new_paths = convert_to_float16(args, old_paths, rank)

elif args.precision == Precision.INT8:
decoder_model_int8_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_model_int8.onnx"
)
decoder_with_past_model_int8_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int8.onnx"
)
decoder_merged_model_int8_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int8.onnx"
)
new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path]

if args.quantization_method == "smooth_quant":
if not args.no_merged:
logger.error("SmoothQuant must be used on separately exported models")
else:
logger.info(
f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8"
)
smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1])

elif args.quantization_method == "quantize_dynamic":
logger.warning(
"The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. Precision loss may be high with `quantize_dynamic`."
)

logger.info("Quantizing to int8...")
for fp32_path, int8_path in zip(old_paths, new_paths):
if os.path.exists(fp32_path):
ort_quantization.quantize_dynamic(
fp32_path,
int8_path,
op_types_to_quantize=(
["MatMul", "Gemm", "Gather"]
if args.quantize_embedding_layer
else ["MatMul", "Gemm"]
),
per_channel=args.quantize_per_channel,
reduce_range=args.quantize_reduce_range,
use_external_data_format=True,
extra_options={"MatMulConstBOnly": True},
)
logger.info(
f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!"
)
remove_existing_model(decoder_model_fp32_path)

logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!")

else:
raise Exception(f"Could not recognize {args.quantization_method} as a quantization method")

elif args.precision == Precision.INT4:
if args.execution_provider != "cpu":
old_paths = convert_to_float16(args, old_paths, rank)

decoder_model_int4_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_model_int4.onnx"
)
decoder_with_past_model_int4_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int4.onnx"
)
decoder_merged_model_int4_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int4.onnx"
)
new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path]

for fp_path, int4_path in zip(old_paths, new_paths):
if os.path.exists(fp_path):
model = onnx.load_model(fp_path, load_external_data=True)
quant = MatMul4BitsQuantizer(
model=model,
block_size=args.block_size,
is_symmetric=True,
accuracy_level=args.int4_accuracy_level,
nodes_to_exclude=[],
)
quant.process()
quant.model.save_model_to_file(int4_path, use_external_data_format=True)
del model
del quant
logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!")
remove_existing_model(fp_path)
# # Run the optimizer script, runs the torch as well
# logger.info("Optimizing models...")
# for orig_path, opt_path in zip(old_paths, new_paths):
# if os.path.exists(orig_path):
# optimize_export(args, l_config, input_path=orig_path, output_path=opt_path, world_size=world_size)

# # Re-assign default FP32 model paths as their optimized versions
# decoder_model_fp32_path = decoder_model_fp32_opt_path
# decoder_with_past_model_fp32_path = decoder_with_past_model_fp32_opt_path
# decoder_merged_model_fp32_path = decoder_merged_model_fp32_opt_path
# old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path]

# logger.info(
# f"The {args.model_name} ONNX model has been successfully optimized with the ORT transformer optimizer script!"
# )

# # Change precision of exported models from FP32
# if args.precision == Precision.FLOAT16:
# new_paths = convert_to_float16(args, old_paths, rank)

# elif args.precision == Precision.INT8:
# decoder_model_int8_path = os.path.join(
# args.output, f"rank_{rank}_{args.model_name}_decoder_model_int8.onnx"
# )
# decoder_with_past_model_int8_path = os.path.join(
# args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int8.onnx"
# )
# decoder_merged_model_int8_path = os.path.join(
# args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int8.onnx"
# )
# new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path]

# if args.quantization_method == "smooth_quant":
# if not args.no_merged:
# logger.error("SmoothQuant must be used on separately exported models")
# else:
# logger.info(
# f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8"
# )
# smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1])

# elif args.quantization_method == "quantize_dynamic":
# logger.warning(
# "The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. Precision loss may be high with `quantize_dynamic`."
# )

# logger.info("Quantizing to int8...")
# for fp32_path, int8_path in zip(old_paths, new_paths):
# if os.path.exists(fp32_path):
# ort_quantization.quantize_dynamic(
# fp32_path,
# int8_path,
# op_types_to_quantize=(
# ["MatMul", "Gemm", "Gather"]
# if args.quantize_embedding_layer
# else ["MatMul", "Gemm"]
# ),
# per_channel=args.quantize_per_channel,
# reduce_range=args.quantize_reduce_range,
# use_external_data_format=True,
# extra_options={"MatMulConstBOnly": True},
# )
# logger.info(
# f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!"
# )
# remove_existing_model(decoder_model_fp32_path)

# logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!")

# else:
# raise Exception(f"Could not recognize {args.quantization_method} as a quantization method")

# elif args.precision == Precision.INT4:
# if args.execution_provider != "cpu":
# old_paths = convert_to_float16(args, old_paths, rank)

# decoder_model_int4_path = os.path.join(
# args.output, f"rank_{rank}_{args.model_name}_decoder_model_int4.onnx"
# )
# decoder_with_past_model_int4_path = os.path.join(
# args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int4.onnx"
# )
# decoder_merged_model_int4_path = os.path.join(
# args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int4.onnx"
# )
# new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path]

# for fp_path, int4_path in zip(old_paths, new_paths):
# if os.path.exists(fp_path):
# model = onnx.load_model(fp_path, load_external_data=True)
# quant = MatMul4BitsQuantizer(
# model=model,
# block_size=args.block_size,
# is_symmetric=True,
# accuracy_level=args.int4_accuracy_level,
# nodes_to_exclude=[],
# )
# quant.process()
# quant.model.save_model_to_file(int4_path, use_external_data_format=True)
# del model
# del quant
# logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!")
# remove_existing_model(fp_path)
barrier()

logger.info("Verifying parity on all ONNX models created")
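The dynamic INT8 path that this hunk comments out reduces to a single call into onnxruntime's quantization package, with the same keyword arguments used in the script. A minimal standalone sketch, using placeholder file paths instead of the rank-prefixed paths built by the script:

# Minimal sketch of the dynamic INT8 quantization step; paths are placeholders.
from onnxruntime import quantization as ort_quantization

fp32_path = "decoder_merged_model_fp32.onnx"  # placeholder input
int8_path = "decoder_merged_model_int8.onnx"  # placeholder output

ort_quantization.quantize_dynamic(
    fp32_path,
    int8_path,
    op_types_to_quantize=["MatMul", "Gemm"],  # add "Gather" to also quantize the embedding layer
    per_channel=False,
    reduce_range=False,
    use_external_data_format=True,  # needed for models larger than 2 GB
    extra_options={"MatMulConstBOnly": True},
)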

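Similarly, the INT4 path in the commented-out hunk wraps onnxruntime's block-wise weight-only quantizer. A minimal sketch, assuming the MatMul4BitsQuantizer import path used by onnxruntime releases of that era and placeholder file paths:

# Minimal sketch of the block-wise INT4 quantization step; import path and file paths are assumptions.
import onnx
from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer

fp_path = "decoder_merged_model_fp16.onnx"  # placeholder input (FP16 or FP32)
int4_path = "decoder_merged_model_int4.onnx"  # placeholder output

model = onnx.load_model(fp_path, load_external_data=True)
quant = MatMul4BitsQuantizer(
    model=model,
    block_size=32,  # quantize MatMul weights in blocks of 32 values
    is_symmetric=True,
    nodes_to_exclude=[],
)
quant.process()  # rewrites eligible MatMul nodes to use 4-bit weights
quant.model.save_model_to_file(int4_path, use_external_data_format=True)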