Skip to content

Commit

Permalink
Reduce LLaMA memory usage (microsoft#18181)
Browse files Browse the repository at this point in the history
### Description
This PR reduces the memory usage when exporting and benchmarking LLaMA.



### Motivation and Context
- Exporting: The PyTorch model is deleted from memory after a successful
export instead of deleting it from memory after exporting + converting
the ONNX model to the desired precision.
- Benchmarking: In the ONNX model with GroupQueryAttention, the KV cache
inputs use the same GPU memory for both the prompt and token generation
benchmarks.
  • Loading branch information
kunal-vaishnavi authored Nov 1, 2023
1 parent 2b95e74 commit d1b85f5
Show file tree
Hide file tree
Showing 4 changed files with 248 additions and 186 deletions.
104 changes: 36 additions & 68 deletions onnxruntime/python/tools/transformers/models/llama/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@
import onnx
import psutil
import torch
from benchmark_helper import setup_logger
from llama_inputs import (
convert_inputs_for_ort,
add_io_bindings,
get_merged_sample_with_past_kv_inputs,
get_msft_sample_inputs,
get_sample_inputs,
Expand All @@ -25,7 +24,7 @@
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer

import onnxruntime as ort
from onnxruntime.transformers.benchmark_helper import measure_memory
from onnxruntime.transformers.benchmark_helper import measure_memory, setup_logger

logger = logging.getLogger(__name__)

Expand All @@ -48,9 +47,19 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
init_inputs, iter_inputs = None, None

# For past_present_share_buffer:
# Set max_seq_len to 2048 for Hugging Face model since that is the default value
# Set max_seq_len to 2048 for Microsoft model since that is the max value currently supported
max_seq_len = 2048
# Set max_seq_len to 16384 for CodeLLaMA (finetuned variant of LLaMA-2)
# Set max_seq_len to 4096 for Hugging Face LLaMA-2 model since that is the default value
# Set max_seq_len to 2048 for Microsoft LLaMA-2 model since that is the max value currently supported
temp_name = args.model_name.lower().replace("-", "").replace("_", "")
max_seq_len = (
2048
if args.benchmark_type == "ort-msft"
else 16384
if "codellama" in temp_name
else 4096
if "llama2" in temp_name
else 2048
)

if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
init_inputs = get_sample_inputs(
Expand Down Expand Up @@ -95,7 +104,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
args.batch_size,
seq_len=args.sequence_length,
past_seq_len=0,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
engine="pt",
return_dict=True,
)
iter_inputs = get_merged_sample_with_past_kv_inputs(
Expand All @@ -104,7 +115,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
args.batch_size,
seq_len=1,
past_seq_len=args.sequence_length,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
engine="pt",
return_dict=True,
)

Expand All @@ -116,7 +129,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
args.batch_size,
seq_len=args.sequence_length,
past_seq_len=0,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
engine="ort",
return_dict=True,
)
iter_inputs = get_merged_sample_with_past_kv_inputs(
Expand All @@ -125,26 +140,10 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
args.batch_size,
seq_len=1,
past_seq_len=args.sequence_length,
use_fp16=args.use_fp16,
return_dict=True,
)
init_inputs = convert_inputs_for_ort(
init_inputs,
use_fp16=args.use_fp16,
use_buffer_share=args.past_present_share_buffer,
past_seq_len=0,
max_seq_len=max_seq_len,
device=args.device,
device_id=args.device_id,
)
iter_inputs = convert_inputs_for_ort(
iter_inputs,
use_fp16=args.use_fp16,
use_buffer_share=args.past_present_share_buffer,
past_seq_len=args.sequence_length,
max_seq_len=max_seq_len,
device=args.device,
device_id=args.device_id,
engine="ort",
return_dict=True,
)

elif args.benchmark_type == "ort-msft":
Expand All @@ -156,6 +155,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
args.batch_size,
past_seq_len=0,
seq_len=args.sequence_length,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
split_kv=split_kv,
)
Expand All @@ -164,26 +164,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
args.batch_size,
past_seq_len=args.sequence_length,
seq_len=1,
use_fp16=args.use_fp16,
split_kv=split_kv,
)
init_inputs = convert_inputs_for_ort(
init_inputs,
use_fp16=args.use_fp16,
use_buffer_share=args.past_present_share_buffer,
past_seq_len=0,
max_seq_len=max_seq_len,
device=args.device,
device_id=args.device_id,
)
iter_inputs = convert_inputs_for_ort(
iter_inputs,
use_fp16=args.use_fp16,
use_buffer_share=args.past_present_share_buffer,
past_seq_len=args.sequence_length,
max_seq_len=max_seq_len,
device=args.device,
device_id=args.device_id,
split_kv=split_kv,
)

else:
Expand Down Expand Up @@ -449,7 +432,7 @@ def get_logits(inputs):


def run_ort_inference(args, init_inputs, iter_inputs, model):
def prepare_ort_inputs(inputs):
def prepare_ort_inputs(inputs, kv_cache_ortvalues):
# Check that all model inputs will be provided
model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs()))
user_inputs = set(inputs.keys())
Expand All @@ -467,29 +450,13 @@ def prepare_ort_inputs(inputs):

# Add IO bindings for non-CPU execution providers
if args.device != "cpu":
io_binding = model.io_binding()

for k, v in inputs.items():
if args.past_present_share_buffer:
# Bind all OrtValue inputs to device
io_binding.bind_ortvalue_input(k, v)
else:
io_binding.bind_cpu_input(k, v)

for output in model.get_outputs():
name = output.name
if args.past_present_share_buffer and ("out" in name or "present" in name):
# Bind present KV cache outputs to OrtValue with buffer sharing
io_binding.bind_ortvalue_output(
name, inputs[name.replace("out", "cache").replace("present", "past_key_values")]
)
else:
io_binding.bind_output(name, device_type=args.device, device_id=args.device_id)

io_binding, kv_cache_ortvalues = add_io_bindings(
model, inputs, args.device, int(args.device_id), kv_cache_ortvalues
)
setattr(args, "io_binding", io_binding) # noqa: B010
return io_binding
return io_binding, kv_cache_ortvalues

return inputs
return inputs, kv_cache_ortvalues

def with_io_binding(io_binding):
# Inference pass with IO binding
Expand All @@ -501,9 +468,10 @@ def without_io_binding(inputs):
return outputs

generate_fn = with_io_binding if args.device != "cpu" else without_io_binding
kv_cache_ortvalues = {}

if args.profile:
ort_init_inputs = prepare_ort_inputs(init_inputs)
ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues)
new_logname = profile_fn(args, generate_fn, ort_init_inputs, "prompt")

# Turn profiling off to stop appending to log file
Expand All @@ -513,7 +481,7 @@ def without_io_binding(inputs):

# Re-initialize model for new log file instead of appending to old log file
model = get_model(args)
ort_iter_inputs = prepare_ort_inputs(iter_inputs)
ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues)
new_logname = profile_fn(args, generate_fn, ort_iter_inputs, "token")

# Turn profiling off to stop appending to log
Expand All @@ -524,12 +492,12 @@ def without_io_binding(inputs):

# ORT evaluations
logger.info("\nEvaluating `model(inputs)` step to get past_key_values")
ort_init_inputs = prepare_ort_inputs(init_inputs)
ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues)
time_fn(args, generate_fn, ort_init_inputs)
measure_fn(args, generate_fn, ort_init_inputs)

logger.info("\nEvaluating `model(inputs)` step with past_key_values")
ort_iter_inputs = prepare_ort_inputs(iter_inputs)
ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues)
time_fn(args, generate_fn, ort_iter_inputs)
measure_fn(args, generate_fn, ort_iter_inputs)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,7 @@ def main():
run_torchscript_separate_export(args, l_config, llama)
else:
run_torchscript_merged_export(args, l_config, llama)
del llama # Delete LLaMA model from memory since it will be loaded again during parity check

# Set model paths to store FP32 optimized model
decoder_model_fp32_opt_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32_opt.onnx")
Expand Down Expand Up @@ -811,7 +812,6 @@ def main():
logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!")
remove_existing_model(fp_path)

del llama # Delete LLaMA model from memory since it will be loaded again during parity check
logger.info("Verifying parity on all ONNX models created")

# Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
Expand Down
Loading

0 comments on commit d1b85f5

Please sign in to comment.