Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enabled Dynamo exporter #21713

Merged
merged 9 commits into from
Aug 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import shutil
import subprocess
import sys
import tempfile
from itertools import chain

import onnx
Expand Down Expand Up @@ -113,34 +114,6 @@ def save_onnx_model(onnx_model: onnx.ModelProto, output_path: str, data_path: st
)


# Notes:
# 1) Dynamo export will not work automatically until this issue is resolved: https://github.com/microsoft/onnxscript/issues/493
#
# 2) Dynamo export will run manually if you set the ONNX file path to the same path that you use to save the model after export.
# In other words, the value of `temp_path` should be set as the ONNX file path. You can open the issue in your browser to find
# the location in ONNX Script where you have to make this change.
#
# Once the issue is resolved, we hope to modify the code below as follows for each export.
#
# Before:
# temp_dir = args.output
# temp_path = os.path.join(temp_dir, "temp.onnx")
# ...
# ...
# ...
# del onnx_model
# os.system(f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}")
#
#
# After:
# temp_dir = tempfile.TemporaryDirectory()
# temp_path = os.path.join(temp_dir.name, "temp.onnx")
# ...
# ...
# ...
# del onnx_model
# temp_dir.cleanup()
#
def run_dynamo_export(
args: argparse.Namespace, l_config: AutoConfig, llama: AutoModelForCausalLM, rank: int = 0, world_size: int = 1
):
Expand All @@ -149,35 +122,25 @@ def run_dynamo_export(
config.capture_scalar_outputs = True

# Dummy values for export
batch_size, sequence_length = 2, 8
device = torch.device("cpu")

# Export decoder_model.onnx
input_ids, attn_mask, pos_ids = get_sample_inputs(l_config, device, batch_size, sequence_length)
temp_dir = args.output # tempfile.TemporaryDirectory()
temp_path = os.path.join(temp_dir, "temp.onnx") # os.path.join(temp_dir.name, "temp.onnx")
torch.onnx.dynamo_export(
llama, input_ids, attn_mask, pos_ids, export_options=torch.onnx.ExportOptions(dynamic_shapes=True)
).save(temp_path)

# Check decoder_model.onnx and save all external data to one file
onnx.checker.check_model(temp_path)
onnx.shape_inference.infer_shapes_path(temp_path)
batch_size, sequence_length, past_sequence_length = 2, 8, 0
device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu")

output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx")
onnx_model = onnx.load_model(temp_path, load_external_data=True)
save_onnx_model(onnx_model, output_path, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx.data")
del onnx_model
os.system(
f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}"
) # temp_dir.cleanup()
temp_name = args.model_name.lower().replace("-", "").replace("_", "")
max_sequence_length = 16384 if "codellama" in temp_name else 4096 if "llama2" in temp_name else 2048

# Export decoder_with_past_model.onnx
input_ids, attn_mask, pos_ids, past_kv = get_sample_with_past_kv_inputs(
l_config, device, batch_size, sequence_length, world_size=world_size
input_ids, attn_mask, pos_ids, past_kv = get_merged_sample_with_past_kv_inputs(
l_config,
device,
batch_size,
sequence_length,
past_sequence_length,
max_seq_len=max_sequence_length,
use_fp16=False,
world_size=world_size,
)
temp_dir = args.output # tempfile.TemporaryDirectory()
temp_path = os.path.join(temp_dir, "temp.onnx") # os.path.join(temp_dir.name, "temp.onnx")
temp_dir = tempfile.TemporaryDirectory()
temp_path = os.path.join(temp_dir.name, "temp.onnx")
torch.onnx.dynamo_export(
llama, input_ids, attn_mask, pos_ids, past_kv, export_options=torch.onnx.ExportOptions(dynamic_shapes=True)
).save(temp_path)
Expand All @@ -190,9 +153,7 @@ def run_dynamo_export(
onnx_model = onnx.load_model(temp_path, load_external_data=True)
save_onnx_model(onnx_model, output_path, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx.data")
del onnx_model
os.system(
f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}"
) # temp_dir.cleanup()
temp_dir.cleanup()

logger.info(f"The {args.model_name} ONNX model has been successfully created with the Dynamo exporter!")

Expand Down Expand Up @@ -869,7 +830,7 @@ def main():

# Export to ONNX
if missing_separate_exports or missing_merged_export:
if args.use_dynamo_export and missing_separate_exports:
if args.use_dynamo_export:
logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.")
logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/")
logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/")
Expand Down Expand Up @@ -902,7 +863,10 @@ def main():
decoder_merged_model_fp32_opt_path,
]

# Run the optimizer script
if args.use_dynamo_export:
continue

# Run the optimizer script.
logger.info("Optimizing models...")
for orig_path, opt_path in zip(old_paths, new_paths):
kobby-kobbs marked this conversation as resolved.
Show resolved Hide resolved
if os.path.exists(orig_path):
Expand Down Expand Up @@ -1007,6 +971,9 @@ def main():
remove_existing_model(fp_path)
barrier()

if args.use_dynamo_export:
kobby-kobbs marked this conversation as resolved.
Show resolved Hide resolved
return

logger.info("Verifying parity on all ONNX models created")
kobby-kobbs marked this conversation as resolved.
Show resolved Hide resolved

# Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
Expand Down
Loading