final pr changes
Browse files Browse the repository at this point in the history
kobby-kobbs committed Aug 16, 2024
1 parent 174888c · commit f1472d8
Showing 3 changed files with 5 additions and 38 deletions.
File 1 of 3:
@@ -11,11 +11,11 @@
 import shutil
 import subprocess
 import sys
+import tempfile
 from itertools import chain

 import onnx
 import torch
-import tempfile
 from benchmark_helper import Precision, prepare_environment, setup_logger
 from convert_generation import replace_mha_with_gqa
 from dist_settings import barrier, get_rank, get_size, init_dist
@@ -114,34 +114,6 @@ def save_onnx_model(onnx_model: onnx.ModelProto, output_path: str, data_path: st
 )


-# Notes:
-# 1) Dynamo export will not work automatically until this issue is resolved: https://github.com/microsoft/onnxscript/issues/493
-#
-# 2) Dynamo export will run manually if you set the ONNX file path to the same path that you use to save the model after export.
-#    In other words, the value of `temp_path` should be set as the ONNX file path. You can open the issue in your browser to find
-#    the location in ONNX Script where you have to make this change.
-#
-# Once the issue is resolved, we hope to modify the code below as follows for each export.
-#
-# Before:
-# temp_dir = args.output
-# temp_path = os.path.join(temp_dir, "temp.onnx")
-# ...
-# ...
-# ...
-# del onnx_model
-# os.system(f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}")
-#
-#
-# After:
-# temp_dir = tempfile.TemporaryDirectory()
-# temp_path = os.path.join(temp_dir.name, "temp.onnx")
-# ...
-# ...
-# ...
-# del onnx_model
-# temp_dir.cleanup()
-#
 def run_dynamo_export(
     args: argparse.Namespace, l_config: AutoConfig, llama: AutoModelForCausalLM, rank: int = 0, world_size: int = 1
 ):
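The note deleted above documents a planned cleanup: stage the export under `tempfile.TemporaryDirectory` instead of writing `temp.onnx` into `args.output` and shelling out to `rm`. Below is a minimal, self-contained sketch of that pattern; the one-node Identity model is a stand-in for the real exported decoder, not code from this repository.

import os
import tempfile

import onnx
from onnx import TensorProto, helper

# A trivial model so the sketch runs on its own (stand-in for the export).
graph = helper.make_graph(
    [helper.make_node("Identity", ["x"], ["y"])],
    "tiny",
    [helper.make_tensor_value_info("x", TensorProto.FLOAT, [1])],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [1])],
)
model = helper.make_model(graph)

# The pattern from the deleted note: intermediate files live in a temporary
# directory, and cleanup() deletes them all at once, replacing the shell
# `rm` calls over model.*, *.weight, and temp.onnx.
temp_dir = tempfile.TemporaryDirectory()
temp_path = os.path.join(temp_dir.name, "temp.onnx")
onnx.save_model(model, temp_path)
onnx_model = onnx.load_model(temp_path)
del onnx_model
temp_dir.cleanup()
assert not os.path.exists(temp_path)  # directory and contents are gone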
@@ -153,12 +125,9 @@ def run_dynamo_export(
     batch_size, sequence_length, past_sequence_length = 2, 8, 0
     device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu")

-
-
     temp_name = args.model_name.lower().replace("-", "").replace("_", "")
     max_sequence_length = 16384 if "codellama" in temp_name else 4096 if "llama2" in temp_name else 2048

-
     # Export decoder_with_past_model.onnx
     input_ids, attn_mask, pos_ids, past_kv = get_merged_sample_with_past_kv_inputs(
         l_config,
@@ -894,10 +863,10 @@ def main():
             decoder_merged_model_fp32_opt_path,
         ]

-        # Run the optimizer script, runs the torch as well.
-        if not args.use_dynamo_export:
+        if args.use_dynamo_export:
+            continue
-            # Rest of optimizer code for TorchScript exported model

+        # Run the optimizer script.
         logger.info("Optimizing models...")
         for orig_path, opt_path in zip(old_paths, new_paths):
             if os.path.exists(orig_path):
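The change above flips the condition into an early `continue` guard: when dynamo export is selected, the loop iteration skips ahead, and the optimizer code below no longer needs an extra level of nesting. A small sketch of the same refactor, using stand-in values and a hypothetical `optimize` helper rather than the script's real code:

# Stand-ins for illustration only.
use_dynamo_export = False
old_paths = ["model_fp32.onnx", "model_with_past_fp32.onnx"]

def optimize(path: str) -> None:
    print(f"optimizing {path}")

# Before: the optimizer block is nested under a negated condition.
for orig_path in old_paths:
    if not use_dynamo_export:
        optimize(orig_path)

# After: an early `continue` keeps the main path at one indent level.
for orig_path in old_paths:
    if use_dynamo_export:
        continue
    optimize(orig_path)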
@@ -1053,4 +1022,4 @@ def main():


 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
File 2 of 3:
@@ -28,7 +28,6 @@ def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool):
 # input_ids: (batch_size, sequence_length)
 # attention_mask: (batch_size, sequence_length)
 # position_ids: (batch_size, sequence_length)
-
 def get_sample_inputs(
     config: AutoConfig,
     device: torch.device,
File 3 of 3:
@@ -307,4 +307,3 @@ def main(argv: list[str] = []):  # noqa: B006
     np.random.seed(seed)
     torch.manual_seed(seed)
     main()
-