Skip to content

Commit

Permalink
dynamo export success
Browse files: browse the repository at this point in the history
  • Loading branch information
kobby-kobbs committed Jul 1, 2024
1 parent 63c13a4 commit 65423a2
Showing 1 changed file with 36 additions and 27 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -15,6 +15,7 @@

import onnx
import torch
import tempfile
from benchmark_helper import Precision, prepare_environment, setup_logger
from convert_generation import replace_mha_with_gqa
from dist_settings import barrier, get_rank, get_size, init_dist
Expand Down Expand Up @@ -149,35 +150,45 @@ def run_dynamo_export(
config.capture_scalar_outputs = True

# Dummy values for export
batch_size, sequence_length = 2, 8
device = torch.device("cpu")
batch_size, sequence_length, past_sequence_length = 2, 8, 0
device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu")

# Export decoder_model.onnx
input_ids, attn_mask, pos_ids = get_sample_inputs(l_config, device, batch_size, sequence_length)
temp_dir = args.output # tempfile.TemporaryDirectory()
temp_path = os.path.join(temp_dir, "temp.onnx") # os.path.join(temp_dir.name, "temp.onnx")
torch.onnx.dynamo_export(
llama, input_ids, attn_mask, pos_ids, export_options=torch.onnx.ExportOptions(dynamic_shapes=True)
).save(temp_path)
# input_ids, attn_mask, pos_ids = get_sample_inputs(l_config, device, batch_size, sequence_length)
# temp_dir = tempfile.TemporaryDirectory()
# temp_path = os.path.join(temp_dir.name, "temp.onnx")
# torch.onnx.dynamo_export(
# llama, input_ids, attn_mask, pos_ids, export_options=torch.onnx.ExportOptions(dynamic_shapes=True)
# ).save(temp_path)

# Check decoder_model.onnx and save all external data to one file
onnx.checker.check_model(temp_path)
onnx.shape_inference.infer_shapes_path(temp_path)
# # Check decoder_model.onnx and save all external data to one file
# onnx.checker.check_model(temp_path)
# onnx.shape_inference.infer_shapes_path(temp_path)

# output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx")
# onnx_model = onnx.load_model(temp_path, load_external_data=True)
# save_onnx_model(onnx_model, output_path, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx.data")
# del onnx_model
# temp_dir.cleanup()


temp_name = args.model_name.lower().replace("-", "").replace("_", "")
max_sequence_length = 16384 if "codellama" in temp_name else 4096 if "llama2" in temp_name else 2048

output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx")
onnx_model = onnx.load_model(temp_path, load_external_data=True)
save_onnx_model(onnx_model, output_path, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx.data")
del onnx_model
os.system(
f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}"
) # temp_dir.cleanup()

# Export decoder_with_past_model.onnx
input_ids, attn_mask, pos_ids, past_kv = get_sample_with_past_kv_inputs(
l_config, device, batch_size, sequence_length, world_size=world_size
input_ids, attn_mask, pos_ids, past_kv = get_merged_sample_with_past_kv_inputs(
l_config,
device,
batch_size,
sequence_length,
past_sequence_length,
max_seq_len=max_sequence_length,
use_fp16=False,
world_size=world_size,
)
temp_dir = args.output # tempfile.TemporaryDirectory()
temp_path = os.path.join(temp_dir, "temp.onnx") # os.path.join(temp_dir.name, "temp.onnx")
temp_dir = tempfile.TemporaryDirectory()
temp_path = os.path.join(temp_dir.name, "temp.onnx")
torch.onnx.dynamo_export(
llama, input_ids, attn_mask, pos_ids, past_kv, export_options=torch.onnx.ExportOptions(dynamic_shapes=True)
).save(temp_path)
Expand All @@ -190,9 +201,7 @@ def run_dynamo_export(
onnx_model = onnx.load_model(temp_path, load_external_data=True)
save_onnx_model(onnx_model, output_path, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx.data")
del onnx_model
os.system(
f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}"
) # temp_dir.cleanup()
temp_dir.cleanup()

logger.info(f"The {args.model_name} ONNX model has been successfully created with the Dynamo exporter!")

Expand Down Expand Up @@ -869,7 +878,7 @@ def main():

# Export to ONNX
if missing_separate_exports or missing_merged_export:
if args.use_dynamo_export and missing_separate_exports:
if args.use_dynamo_export and missing_merged_export:
logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.")
logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/")
logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/")
Expand Down Expand Up @@ -1056,4 +1065,4 @@ def main():


if __name__ == "__main__":
main()
main()

0 comments on commit 65423a2

Please sign in to comment.