Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enabled Dynamo exporter #21713

Merged
merged 9 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 113 additions & 120 deletions onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import onnx
import torch
import tempfile
from benchmark_helper import Precision, prepare_environment, setup_logger
from convert_generation import replace_mha_with_gqa
from dist_settings import barrier, get_rank, get_size, init_dist
Expand Down Expand Up @@ -149,35 +150,28 @@ def run_dynamo_export(
config.capture_scalar_outputs = True

# Dummy values for export
batch_size, sequence_length = 2, 8
device = torch.device("cpu")
batch_size, sequence_length, past_sequence_length = 2, 8, 0
device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu")

# Export decoder_model.onnx
input_ids, attn_mask, pos_ids = get_sample_inputs(l_config, device, batch_size, sequence_length)
temp_dir = args.output # tempfile.TemporaryDirectory()
temp_path = os.path.join(temp_dir, "temp.onnx") # os.path.join(temp_dir.name, "temp.onnx")
torch.onnx.dynamo_export(
llama, input_ids, attn_mask, pos_ids, export_options=torch.onnx.ExportOptions(dynamic_shapes=True)
).save(temp_path)

# Check decoder_model.onnx and save all external data to one file
onnx.checker.check_model(temp_path)
onnx.shape_inference.infer_shapes_path(temp_path)

output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx")
onnx_model = onnx.load_model(temp_path, load_external_data=True)
save_onnx_model(onnx_model, output_path, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx.data")
del onnx_model
os.system(
f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}"
) # temp_dir.cleanup()
temp_name = args.model_name.lower().replace("-", "").replace("_", "")
max_sequence_length = 16384 if "codellama" in temp_name else 4096 if "llama2" in temp_name else 2048


# Export decoder_with_past_model.onnx
input_ids, attn_mask, pos_ids, past_kv = get_sample_with_past_kv_inputs(
l_config, device, batch_size, sequence_length, world_size=world_size
input_ids, attn_mask, pos_ids, past_kv = get_merged_sample_with_past_kv_inputs(
l_config,
device,
batch_size,
sequence_length,
past_sequence_length,
max_seq_len=max_sequence_length,
use_fp16=False,
world_size=world_size,
)
temp_dir = args.output # tempfile.TemporaryDirectory()
temp_path = os.path.join(temp_dir, "temp.onnx") # os.path.join(temp_dir.name, "temp.onnx")
temp_dir = tempfile.TemporaryDirectory()
temp_path = os.path.join(temp_dir.name, "temp.onnx")
torch.onnx.dynamo_export(
llama, input_ids, attn_mask, pos_ids, past_kv, export_options=torch.onnx.ExportOptions(dynamic_shapes=True)
).save(temp_path)
Expand All @@ -190,9 +184,7 @@ def run_dynamo_export(
onnx_model = onnx.load_model(temp_path, load_external_data=True)
save_onnx_model(onnx_model, output_path, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx.data")
del onnx_model
os.system(
f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}"
) # temp_dir.cleanup()
temp_dir.cleanup()

logger.info(f"The {args.model_name} ONNX model has been successfully created with the Dynamo exporter!")

Expand Down Expand Up @@ -869,7 +861,7 @@ def main():

# Export to ONNX
if missing_separate_exports or missing_merged_export:
if args.use_dynamo_export and missing_separate_exports:
if args.use_dynamo_export:
logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.")
logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/")
logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/")
Expand Down Expand Up @@ -902,109 +894,110 @@ def main():
decoder_merged_model_fp32_opt_path,
]

# Run the optimizer script
logger.info("Optimizing models...")
for orig_path, opt_path in zip(old_paths, new_paths):
kobby-kobbs marked this conversation as resolved.
Show resolved Hide resolved
if os.path.exists(orig_path):
optimize_export(args, l_config, input_path=orig_path, output_path=opt_path, world_size=world_size)

# Re-assign default FP32 model paths as their optimized versions
decoder_model_fp32_path = decoder_model_fp32_opt_path
decoder_with_past_model_fp32_path = decoder_with_past_model_fp32_opt_path
decoder_merged_model_fp32_path = decoder_merged_model_fp32_opt_path
old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path]

logger.info(
f"The {args.model_name} ONNX model has been successfully optimized with the ORT transformer optimizer script!"
)

# Change precision of exported models from FP32
if args.precision == Precision.FLOAT16:
new_paths = convert_to_float16(args, old_paths, rank)

elif args.precision == Precision.INT8:
decoder_model_int8_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_model_int8.onnx"
)
decoder_with_past_model_int8_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int8.onnx"
)
decoder_merged_model_int8_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int8.onnx"
# Run the optimizer script, runs the torch as well.
kobby-kobbs marked this conversation as resolved.
Show resolved Hide resolved
if not args.use_dynamo_export:
kobby-kobbs marked this conversation as resolved.
Show resolved Hide resolved
kobby-kobbs marked this conversation as resolved.
Show resolved Hide resolved
logger.info("Optimizing models...")
for orig_path, opt_path in zip(old_paths, new_paths):
if os.path.exists(orig_path):
optimize_export(args, l_config, input_path=orig_path, output_path=opt_path, world_size=world_size)

# Re-assign default FP32 model paths as their optimized versions
decoder_model_fp32_path = decoder_model_fp32_opt_path
decoder_with_past_model_fp32_path = decoder_with_past_model_fp32_opt_path
decoder_merged_model_fp32_path = decoder_merged_model_fp32_opt_path
old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path]

logger.info(
f"The {args.model_name} ONNX model has been successfully optimized with the ORT transformer optimizer script!"
)
new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path]

if args.quantization_method == "smooth_quant":
if not args.no_merged:
logger.error("SmoothQuant must be used on separately exported models")
else:
logger.info(
f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8"
)
smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1])
# Change precision of exported models from FP32
if args.precision == Precision.FLOAT16:
new_paths = convert_to_float16(args, old_paths, rank)
Fixed Show fixed Hide fixed

elif args.quantization_method == "quantize_dynamic":
logger.warning(
"The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. Precision loss may be high with `quantize_dynamic`."
elif args.precision == Precision.INT8:
decoder_model_int8_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_model_int8.onnx"
)
decoder_with_past_model_int8_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int8.onnx"
)
decoder_merged_model_int8_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int8.onnx"
)
new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path]

logger.info("Quantizing to int8...")
for fp32_path, int8_path in zip(old_paths, new_paths):
if os.path.exists(fp32_path):
ort_quantization.quantize_dynamic(
fp32_path,
int8_path,
op_types_to_quantize=(
["MatMul", "Gemm", "Gather"]
if args.quantize_embedding_layer
else ["MatMul", "Gemm"]
),
per_channel=args.quantize_per_channel,
reduce_range=args.quantize_reduce_range,
use_external_data_format=True,
extra_options={"MatMulConstBOnly": True},
)
if args.quantization_method == "smooth_quant":
if not args.no_merged:
logger.error("SmoothQuant must be used on separately exported models")
else:
logger.info(
f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!"
f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8"
)
remove_existing_model(decoder_model_fp32_path)
smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1])

logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!")
elif args.quantization_method == "quantize_dynamic":
logger.warning(
"The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. Precision loss may be high with `quantize_dynamic`."
)

else:
raise Exception(f"Could not recognize {args.quantization_method} as a quantization method")
logger.info("Quantizing to int8...")
for fp32_path, int8_path in zip(old_paths, new_paths):
if os.path.exists(fp32_path):
ort_quantization.quantize_dynamic(
fp32_path,
int8_path,
op_types_to_quantize=(
["MatMul", "Gemm", "Gather"]
if args.quantize_embedding_layer
else ["MatMul", "Gemm"]
),
per_channel=args.quantize_per_channel,
reduce_range=args.quantize_reduce_range,
use_external_data_format=True,
extra_options={"MatMulConstBOnly": True},
)
logger.info(
f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!"
)
remove_existing_model(decoder_model_fp32_path)

logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!")

elif args.precision == Precision.INT4:
if args.execution_provider != "cpu":
old_paths = convert_to_float16(args, old_paths, rank)
else:
raise Exception(f"Could not recognize {args.quantization_method} as a quantization method")

decoder_model_int4_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_model_int4.onnx"
)
decoder_with_past_model_int4_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int4.onnx"
)
decoder_merged_model_int4_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int4.onnx"
)
new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path]

for fp_path, int4_path in zip(old_paths, new_paths):
if os.path.exists(fp_path):
model = onnx.load_model(fp_path, load_external_data=True)
quant = MatMul4BitsQuantizer(
model=model,
block_size=args.block_size,
is_symmetric=True,
accuracy_level=args.int4_accuracy_level,
nodes_to_exclude=[],
)
quant.process()
quant.model.save_model_to_file(int4_path, use_external_data_format=True)
del model
del quant
logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!")
remove_existing_model(fp_path)
elif args.precision == Precision.INT4:
if args.execution_provider != "cpu":
old_paths = convert_to_float16(args, old_paths, rank)

decoder_model_int4_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_model_int4.onnx"
)
decoder_with_past_model_int4_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int4.onnx"
)
decoder_merged_model_int4_path = os.path.join(
args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int4.onnx"
)
new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path]

for fp_path, int4_path in zip(old_paths, new_paths):
if os.path.exists(fp_path):
model = onnx.load_model(fp_path, load_external_data=True)
quant = MatMul4BitsQuantizer(
model=model,
block_size=args.block_size,
is_symmetric=True,
accuracy_level=args.int4_accuracy_level,
nodes_to_exclude=[],
)
quant.process()
quant.model.save_model_to_file(int4_path, use_external_data_format=True)
del model
del quant
logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!")
remove_existing_model(fp_path)
barrier()

logger.info("Verifying parity on all ONNX models created")
kobby-kobbs marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -1056,4 +1049,4 @@ def main():


if __name__ == "__main__":
main()
main()
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool):
# input_ids: (batch_size, sequence_length)
# attention_mask: (batch_size, sequence_length)
# position_ids: (batch_size, sequence_length)

kobby-kobbs marked this conversation as resolved.
Show resolved Hide resolved
def get_sample_inputs(
config: AutoConfig,
device: torch.device,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,4 @@ def main(argv: list[str] = []): # noqa: B006
np.random.seed(seed)
torch.manual_seed(seed)
main()

kobby-kobbs marked this conversation as resolved.
Show resolved Hide resolved
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
Loading