Reduce LLaMA memory usage #18181

Merged
85 changes: 21 additions & 64 deletions onnxruntime/python/tools/transformers/models/llama/benchmark.py
@@ -11,8 +11,9 @@
import onnx
import psutil
import torch
from benchmark_helper import setup_logger
from benchmark_helper import measure_memory, setup_logger
from llama_inputs import (
add_io_bindings,
convert_inputs_for_ort,
get_merged_sample_with_past_kv_inputs,
get_msft_sample_inputs,
@@ -25,7 +26,6 @@
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer

import onnxruntime as ort
from onnxruntime.transformers.benchmark_helper import measure_memory

logger = logging.getLogger(__name__)

@@ -95,7 +95,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
args.batch_size,
seq_len=args.sequence_length,
past_seq_len=0,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
engine="pt",
return_dict=True,
)
iter_inputs = get_merged_sample_with_past_kv_inputs(
@@ -104,7 +106,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
args.batch_size,
seq_len=1,
past_seq_len=args.sequence_length,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
engine="pt",
return_dict=True,
)

@@ -116,7 +120,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
args.batch_size,
seq_len=args.sequence_length,
past_seq_len=0,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
engine="ort",
return_dict=True,
)
iter_inputs = get_merged_sample_with_past_kv_inputs(
@@ -125,26 +131,10 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
args.batch_size,
seq_len=1,
past_seq_len=args.sequence_length,
use_fp16=args.use_fp16,
return_dict=True,
)
init_inputs = convert_inputs_for_ort(
init_inputs,
use_fp16=args.use_fp16,
use_buffer_share=args.past_present_share_buffer,
past_seq_len=0,
max_seq_len=max_seq_len,
device=args.device,
device_id=args.device_id,
)
iter_inputs = convert_inputs_for_ort(
iter_inputs,
use_fp16=args.use_fp16,
use_buffer_share=args.past_present_share_buffer,
past_seq_len=args.sequence_length,
max_seq_len=max_seq_len,
device=args.device,
device_id=args.device_id,
engine="ort",
return_dict=True,
)

elif args.benchmark_type == "ort-msft":
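Note: with engine="ort" and max_seq_len now passed straight into get_merged_sample_with_past_kv_inputs, the separate convert_inputs_for_ort pass deleted above is folded into the input helper; the ort-msft branch below gets the same treatment. The memory saving presumably comes from allocating the past KV cache once, at max_seq_len, as device-resident OrtValues instead of rebuilding host tensors and copying them each step. A minimal sketch of that idea (the shapes, argument names, and input naming are illustrative assumptions, not the repo's exact helper):

import numpy as np
from onnxruntime import OrtValue

def make_shared_kv_cache(batch, n_heads, head_dim, n_layers, max_seq_len,
                         device="cuda", device_id=0, use_fp16=True):
    # Allocate every past key/value tensor once, sized for max_seq_len, on the target device.
    # Later decoding steps write into these buffers instead of allocating new host tensors.
    dtype = np.float16 if use_fp16 else np.float32
    cache = {}
    for layer in range(n_layers):
        shape = (batch, n_heads, max_seq_len, head_dim)
        cache[f"past_key_values.{layer}.key"] = OrtValue.ortvalue_from_numpy(
            np.zeros(shape, dtype=dtype), device, device_id
        )
        cache[f"past_key_values.{layer}.value"] = OrtValue.ortvalue_from_numpy(
            np.zeros(shape, dtype=dtype), device, device_id
        )
    return cache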
@@ -156,6 +146,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
args.batch_size,
past_seq_len=0,
seq_len=args.sequence_length,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
split_kv=split_kv,
)
@@ -164,26 +155,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
args.batch_size,
past_seq_len=args.sequence_length,
seq_len=1,
use_fp16=args.use_fp16,
split_kv=split_kv,
)
init_inputs = convert_inputs_for_ort(
init_inputs,
use_fp16=args.use_fp16,
use_buffer_share=args.past_present_share_buffer,
past_seq_len=0,
max_seq_len=max_seq_len,
device=args.device,
device_id=args.device_id,
)
iter_inputs = convert_inputs_for_ort(
iter_inputs,
use_fp16=args.use_fp16,
use_buffer_share=args.past_present_share_buffer,
past_seq_len=args.sequence_length,
max_seq_len=max_seq_len,
device=args.device,
device_id=args.device_id,
split_kv=split_kv,
)

else:
@@ -449,7 +423,7 @@ def get_logits(inputs):


def run_ort_inference(args, init_inputs, iter_inputs, model):
def prepare_ort_inputs(inputs):
def prepare_ort_inputs(inputs, kv_cache_ortvalues):
# Check that all model inputs will be provided
model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs()))
user_inputs = set(inputs.keys())
@@ -467,29 +441,11 @@ def prepare_ort_inputs(inputs):

# Add IO bindings for non-CPU execution providers
if args.device != "cpu":
io_binding = model.io_binding()

for k, v in inputs.items():
if args.past_present_share_buffer:
# Bind all OrtValue inputs to device
io_binding.bind_ortvalue_input(k, v)
else:
io_binding.bind_cpu_input(k, v)

for output in model.get_outputs():
name = output.name
if args.past_present_share_buffer and ("out" in name or "present" in name):
# Bind present KV cache outputs to OrtValue with buffer sharing
io_binding.bind_ortvalue_output(
name, inputs[name.replace("out", "cache").replace("present", "past_key_values")]
)
else:
io_binding.bind_output(name, device_type=args.device, device_id=args.device_id)

io_binding, kv_cache_ortvalues = add_io_bindings(model, inputs, args.device, int(args.device_id), kv_cache_ortvalues)
setattr(args, "io_binding", io_binding) # noqa: B010
return io_binding
return io_binding, kv_cache_ortvalues

return inputs
return inputs, kv_cache_ortvalues
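The loop deleted above is what add_io_bindings (imported from llama_inputs at the top of the file) now centralizes; in addition, it threads the kv_cache_ortvalues dict through so KV buffers bound on the first call can be bound again on later calls. A rough sketch of that shape, under the assumption that buffer-shared inputs arrive as device OrtValues (simplified stand-in, not the repo's exact helper):

from onnxruntime import OrtValue

def add_io_bindings_sketch(model, inputs, device, device_id, kv_cache_ortvalues):
    # Simplified stand-in for llama_inputs.add_io_bindings.
    io_binding = model.io_binding()

    for name, value in inputs.items():
        if isinstance(value, OrtValue):
            io_binding.bind_ortvalue_input(name, value)  # already device-resident (e.g. KV cache)
        else:
            io_binding.bind_cpu_input(name, value)       # plain numpy array, copied by ORT

    for output in model.get_outputs():
        name = output.name
        past_name = name.replace("out", "cache").replace("present", "past_key_values")
        if ("out" in name or "present" in name) and past_name in inputs:
            # Buffer sharing: write the present KV straight into the past KV OrtValue
            # and remember it so the next call can rebind the same device memory.
            io_binding.bind_ortvalue_output(name, inputs[past_name])
            kv_cache_ortvalues[past_name] = inputs[past_name]
        else:
            io_binding.bind_output(name, device_type=device, device_id=device_id)

    return io_binding, kv_cache_ortvalues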

def with_io_binding(io_binding):
# Inference pass with IO binding
@@ -501,9 +457,10 @@ def without_io_binding(inputs):
return outputs

generate_fn = with_io_binding if args.device != "cpu" else without_io_binding
kv_cache_ortvalues = {}

if args.profile:
ort_init_inputs = prepare_ort_inputs(init_inputs)
ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues)
new_logname = profile_fn(args, generate_fn, ort_init_inputs, "prompt")

# Turn profiling off to stop appending to log file
@@ -513,7 +470,7 @@ def without_io_binding(inputs):

# Re-initialize model for new log file instead of appending to old log file
model = get_model(args)
ort_iter_inputs = prepare_ort_inputs(iter_inputs)
ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues)
new_logname = profile_fn(args, generate_fn, ort_iter_inputs, "token")

# Turn profiling off to stop appending to log
@@ -524,12 +481,12 @@ def without_io_binding(inputs):

# ORT evaluations
logger.info("\nEvaluating `model(inputs)` step to get past_key_values")
ort_init_inputs = prepare_ort_inputs(init_inputs)
ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues)
time_fn(args, generate_fn, ort_init_inputs)
measure_fn(args, generate_fn, ort_init_inputs)

logger.info("\nEvaluating `model(inputs)` step with past_key_values")
ort_iter_inputs = prepare_ort_inputs(iter_inputs)
ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues)
time_fn(args, generate_fn, ort_iter_inputs)
measure_fn(args, generate_fn, ort_iter_inputs)
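Note: because prepare_ort_inputs now returns the dict it was given, the prompt pass and the token pass above share one set of device-side KV cache OrtValues; only the first call allocates them and later calls rebind the same memory, which is presumably where much of the benchmark-side memory reduction comes from.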

(second changed file in this PR)
@@ -716,6 +716,7 @@ def main():
run_torchscript_separate_export(args, l_config, llama)
else:
run_torchscript_merged_export(args, l_config, llama)
del llama # Delete LLaMA model from memory since it will be loaded again during parity check

# Set model paths to store FP32 optimized model
decoder_model_fp32_opt_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32_opt.onnx")
@@ -811,7 +812,6 @@ def main():
logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!")
remove_existing_model(fp_path)

del llama # Delete LLaMA model from memory since it will be loaded again during parity check
logger.info("Verifying parity on all ONNX models created")

# Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
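Note: the two hunks above move del llama from after quantization to right after the TorchScript export, so the in-memory PyTorch weights are released before the ONNX optimization and quantization passes run; the model is reloaded later for the parity check. A hedged illustration of the general pattern (the gc/empty_cache calls are illustrative extras, not part of this change, and the checkpoint name is only an example):

import gc

import torch
from transformers import LlamaForCausalLM

llama = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # example checkpoint
# ... export the model to ONNX here ...
del llama                      # drop the PyTorch weights as soon as the export is done
gc.collect()                   # reclaim host memory before the heavy ONNX optimization passes
if torch.cuda.is_available():
    torch.cuda.empty_cache()   # also return cached CUDA blocks if the export ran on GPU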