Commit

Added changes to enable dynamo exporter
kobby-kobbs committed Aug 12, 2024
1 parent 16c94c8 commit 6af5245
Showing 4 changed files with 70 additions and 4 deletions.
@@ -153,7 +153,7 @@ def run_dynamo_export(
batch_size, sequence_length, past_sequence_length = 2, 8, 0
device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu")

# Export decoder_model.onnx
# Export decoder_model.onnx (commented out for now so that dynamo_export is used instead)
# input_ids, attn_mask, pos_ids = get_sample_inputs(l_config, device, batch_size, sequence_length)
# temp_dir = tempfile.TemporaryDirectory()
# temp_path = os.path.join(temp_dir.name, "temp.onnx")
@@ -911,7 +911,7 @@ def main():
decoder_merged_model_fp32_opt_path,
]

# # Run the optimizer script, runs the torch as well
# # Run the optimizer script (which also runs the torch export). Keeping this block commented out ensures only the Dynamo export path is used.
# logger.info("Optimizing models...")
# for orig_path, opt_path in zip(old_paths, new_paths):
# if os.path.exists(orig_path):
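For orientation, here is a minimal sketch of the export path this change switches to, assuming torch >= 2.1 so that torch.onnx.dynamo_export is available. It reuses the get_dynamo_inputs helper added to llama_inputs.py below; the function name, output filename, and max_seq_len value are illustrative assumptions, not code from this commit.

import torch

from llama_inputs import get_dynamo_inputs  # helper added in llama_inputs.py below


def export_with_dynamo(llama, l_config, device, out_path="decoder_model_dynamo.onnx"):
    # Prompt-style sample inputs (no past KV cache), mirroring the sizes used above.
    batch_size, sequence_length, past_sequence_length = 2, 8, 0
    input_ids, attn_mask, pos_ids, _past_kv = get_dynamo_inputs(
        l_config,
        device,
        batch_size,
        sequence_length,
        past_sequence_length,
        max_seq_len=2048,  # assumed cap; only used when buffer sharing is enabled
    )
    # torch.onnx.dynamo_export traces with TorchDynamo/FX instead of torch.jit tracing.
    onnx_program = torch.onnx.dynamo_export(llama, input_ids, attn_mask, pos_ids)
    onnx_program.save(out_path)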
53 changes: 53 additions & 0 deletions onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
@@ -28,6 +28,7 @@ def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool):
# input_ids: (batch_size, sequence_length)
# attention_mask: (batch_size, sequence_length)
# position_ids: (batch_size, sequence_length)

def get_sample_inputs(
config: AutoConfig,
device: torch.device,
@@ -171,6 +172,58 @@ def get_merged_sample_with_past_kv_inputs(

return inputs

def get_dynamo_inputs(
config: AutoConfig,
device: torch.device,
batch_size: int,
seq_len: int,
past_seq_len: int,
max_seq_len: int,
use_fp16: bool = False,
use_buffer_share: bool = False,
engine: str = "pt",
return_dict: bool = False,
world_size: int = 1,
):
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
attention_mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64)
# position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation
position_ids = get_position_ids(attention_mask, use_past_kv=(past_seq_len != 0))
past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size)

# Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
past_kv = (
flatten_past_kv_inputs(past_kv)
if engine == "ort"
else list(map(lambda kv: (kv[0].to(device), kv[1].to(device)), past_kv))
)

if not return_dict:
# For export
assert isinstance(past_kv, list)
return (input_ids, attention_mask, position_ids, past_kv)

inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
if engine == "ort":
assert isinstance(past_kv, dict)
inputs.update(past_kv)

if use_buffer_share:
inputs = enable_past_present_share_buffer(inputs, past_seq_len, max_seq_len)

else:
assert isinstance(past_kv, list)
inputs["past_key_values"] = past_kv

return inputs


# Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx
def get_msft_sample_inputs(
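A hedged usage sketch for the new get_dynamo_inputs helper: with engine="ort" and return_dict=True the past key/value tensors are flattened into a NumPy feed dict, which can be passed to an onnxruntime InferenceSession provided the exported model's input names match the flattened keys. The model id, ONNX path, and sequence lengths below are placeholders, not values taken from this commit.

import onnxruntime as ort
import torch
from transformers import AutoConfig

from llama_inputs import get_dynamo_inputs

config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder model id
ort_feed = get_dynamo_inputs(
    config,
    torch.device("cpu"),  # unused on the ORT path; tensors are converted to NumPy instead
    batch_size=2,
    seq_len=8,
    past_seq_len=0,
    max_seq_len=2048,
    engine="ort",
    return_dict=True,
)
session = ort.InferenceSession("decoder_model_dynamo.onnx", providers=["CPUExecutionProvider"])
logits = session.run(None, ort_feed)[0]  # assumes logits is the first graph output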
16 changes: 14 additions & 2 deletions onnxruntime/python/tools/transformers/models/llama/llama_parity.py
@@ -21,6 +21,7 @@
get_sample_inputs,
get_sample_with_past_kv_inputs,
verify_ort_inputs,
get_dynamo_inputs
)
from llama_torch import setup_torch_model
from transformers import AutoConfig
@@ -41,8 +42,10 @@ def get_inputs(args: argparse.Namespace, config: AutoConfig):
world_size = get_size()
batch_size = 2
past_sequence_length, sequence_length, max_sequence_length = get_sequence_lengths(args, config)

if args.merged:

if args.dynamo:
inputs = get_dynamo_inputs(config, args.device, batch_size, sequence_length, past_sequence_length, max_sequence_length, return_dict=True)
elif args.merged:
inputs = get_merged_sample_with_past_kv_inputs(
config,
args.device,
@@ -166,6 +169,7 @@ def verify_parity(
def get_args(argv: list[str]):
parser = argparse.ArgumentParser()


parser.add_argument(
"-m",
"--model_name",
@@ -236,6 +240,14 @@ def get_args(argv: list[str]):
choices=["int4", "int8", "fp16", "fp32"],
help="Precision of model",
)

parser.add_argument(
"--dynamo",
action="store_true",
help="Use Dynamo model inputs for parity check",
)
parser.set_defaults(dynamo=False)


parser.add_argument(
"--cache_dir",
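To make the role of the new --dynamo flag concrete, here is a minimal sketch of the comparison the parity script performs, under the assumption that the PyTorch-side inputs are the dict form from get_dynamo_inputs(engine="pt", return_dict=True), the ORT-side inputs are the NumPy dict form, and the first ONNX output is the logits tensor; the tolerances are placeholders.

import numpy as np
import torch


def check_logits_parity(llama, ort_session, pt_inputs, ort_feed, rtol=1e-3, atol=1e-3):
    # Reference logits from the PyTorch model.
    with torch.no_grad():
        pt_logits = llama(**pt_inputs).logits.cpu().numpy()
    # Logits from the exported ONNX model via ONNX Runtime.
    ort_logits = ort_session.run(None, ort_feed)[0]
    return np.allclose(pt_logits, ort_logits, rtol=rtol, atol=atol)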
@@ -29,6 +29,7 @@ def setup_torch_model(args, location, auth, torch_dtype=torch.float32, device=None):
)
l_config.use_cache = True
l_config._attn_implementation = "eager" # "eager" uses LlamaAttention for attention layer
# l_config.num_hidden_layers = 1
llama = AutoModelForCausalLM.from_pretrained(
location,
use_auth_token=auth,
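The commented-out num_hidden_layers line above is a debugging knob: shrinking the config to a single decoder layer makes export and parity runs much faster while keeping the per-layer graph identical. A sketch of that debug setup, with a placeholder model id and assuming from_pretrained simply skips the weights of the dropped layers:

import torch
from transformers import AutoConfig, AutoModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder
l_config = AutoConfig.from_pretrained(model_id)
l_config.use_cache = True
l_config._attn_implementation = "eager"
l_config.num_hidden_layers = 1  # debug only: 1 layer instead of 32; remaining weights are ignored
llama = AutoModelForCausalLM.from_pretrained(model_id, config=l_config, torch_dtype=torch.float32)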
