diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark.py b/onnxruntime/python/tools/transformers/models/llama/benchmark.py
index a721979eb0bcb..245ff3dfe7f9d 100644
--- a/onnxruntime/python/tools/transformers/models/llama/benchmark.py
+++ b/onnxruntime/python/tools/transformers/models/llama/benchmark.py
@@ -11,9 +11,8 @@
 import onnx
 import psutil
 import torch
-from benchmark_helper import setup_logger
 from llama_inputs import (
-    convert_inputs_for_ort,
+    add_io_bindings,
     get_merged_sample_with_past_kv_inputs,
     get_msft_sample_inputs,
     get_sample_inputs,
@@ -25,7 +24,7 @@
 from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
 
 import onnxruntime as ort
-from onnxruntime.transformers.benchmark_helper import measure_memory
+from onnxruntime.transformers.benchmark_helper import measure_memory, setup_logger
 
 logger = logging.getLogger(__name__)
 
@@ -48,9 +47,19 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
     init_inputs, iter_inputs = None, None
 
     # For past_present_share_buffer:
-    # Set max_seq_len to 2048 for Hugging Face model since that is the default value
-    # Set max_seq_len to 2048 for Microsoft model since that is the max value currently supported
-    max_seq_len = 2048
+    # Set max_seq_len to 16384 for CodeLLaMA (finetuned variant of LLaMA-2)
+    # Set max_seq_len to 4096 for Hugging Face LLaMA-2 model since that is the default value
+    # Set max_seq_len to 2048 for Microsoft LLaMA-2 model since that is the max value currently supported
+    temp_name = args.model_name.lower().replace("-", "").replace("_", "")
+    max_seq_len = (
+        2048
+        if args.benchmark_type == "ort-msft"
+        else 16384
+        if "codellama" in temp_name
+        else 4096
+        if "llama2" in temp_name
+        else 2048
+    )
 
     if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
         init_inputs = get_sample_inputs(
@@ -95,7 +104,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
             args.batch_size,
             seq_len=args.sequence_length,
             past_seq_len=0,
+            max_seq_len=max_seq_len,
             use_fp16=args.use_fp16,
+            engine="pt",
             return_dict=True,
         )
         iter_inputs = get_merged_sample_with_past_kv_inputs(
@@ -104,7 +115,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
             args.batch_size,
             seq_len=1,
             past_seq_len=args.sequence_length,
+            max_seq_len=max_seq_len,
             use_fp16=args.use_fp16,
+            engine="pt",
             return_dict=True,
         )
 
@@ -116,7 +129,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
             args.batch_size,
             seq_len=args.sequence_length,
             past_seq_len=0,
+            max_seq_len=max_seq_len,
             use_fp16=args.use_fp16,
+            engine="ort",
             return_dict=True,
         )
         iter_inputs = get_merged_sample_with_past_kv_inputs(
@@ -125,26 +140,10 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
             args.batch_size,
             seq_len=1,
             past_seq_len=args.sequence_length,
-            use_fp16=args.use_fp16,
-            return_dict=True,
-        )
-        init_inputs = convert_inputs_for_ort(
-            init_inputs,
-            use_fp16=args.use_fp16,
-            use_buffer_share=args.past_present_share_buffer,
-            past_seq_len=0,
             max_seq_len=max_seq_len,
-            device=args.device,
-            device_id=args.device_id,
-        )
-        iter_inputs = convert_inputs_for_ort(
-            iter_inputs,
             use_fp16=args.use_fp16,
-            use_buffer_share=args.past_present_share_buffer,
-            past_seq_len=args.sequence_length,
-            max_seq_len=max_seq_len,
-            device=args.device,
-            device_id=args.device_id,
+            engine="ort",
+            return_dict=True,
         )
 
     elif args.benchmark_type == "ort-msft":
@@ -156,6 +155,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
             args.batch_size,
             past_seq_len=0,
             seq_len=args.sequence_length,
+            max_seq_len=max_seq_len,
             use_fp16=args.use_fp16,
             split_kv=split_kv,
         )
@@ -164,26 +164,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
             args.batch_size,
             past_seq_len=args.sequence_length,
             seq_len=1,
-            use_fp16=args.use_fp16,
-            split_kv=split_kv,
-        )
-        init_inputs = convert_inputs_for_ort(
-            init_inputs,
-            use_fp16=args.use_fp16,
-            use_buffer_share=args.past_present_share_buffer,
-            past_seq_len=0,
             max_seq_len=max_seq_len,
-            device=args.device,
-            device_id=args.device_id,
-        )
-        iter_inputs = convert_inputs_for_ort(
-            iter_inputs,
             use_fp16=args.use_fp16,
-            use_buffer_share=args.past_present_share_buffer,
-            past_seq_len=args.sequence_length,
-            max_seq_len=max_seq_len,
-            device=args.device,
-            device_id=args.device_id,
+            split_kv=split_kv,
         )
 
     else:
@@ -449,7 +432,7 @@ def get_logits(inputs):
 
 
 def run_ort_inference(args, init_inputs, iter_inputs, model):
-    def prepare_ort_inputs(inputs):
+    def prepare_ort_inputs(inputs, kv_cache_ortvalues):
         # Check that all model inputs will be provided
         model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs()))
         user_inputs = set(inputs.keys())
@@ -467,29 +450,13 @@ def prepare_ort_inputs(inputs):
 
         # Add IO bindings for non-CPU execution providers
         if args.device != "cpu":
-            io_binding = model.io_binding()
-
-            for k, v in inputs.items():
-                if args.past_present_share_buffer:
-                    # Bind all OrtValue inputs to device
-                    io_binding.bind_ortvalue_input(k, v)
-                else:
-                    io_binding.bind_cpu_input(k, v)
-
-            for output in model.get_outputs():
-                name = output.name
-                if args.past_present_share_buffer and ("out" in name or "present" in name):
-                    # Bind present KV cache outputs to OrtValue with buffer sharing
-                    io_binding.bind_ortvalue_output(
-                        name, inputs[name.replace("out", "cache").replace("present", "past_key_values")]
-                    )
-                else:
-                    io_binding.bind_output(name, device_type=args.device, device_id=args.device_id)
-
+            io_binding, kv_cache_ortvalues = add_io_bindings(
+                model, inputs, args.device, int(args.device_id), kv_cache_ortvalues
+            )
             setattr(args, "io_binding", io_binding)  # noqa: B010
-            return io_binding
+            return io_binding, kv_cache_ortvalues
 
-        return inputs
+        return inputs, kv_cache_ortvalues
 
     def with_io_binding(io_binding):
         # Inference pass with IO binding
@@ -501,9 +468,10 @@ def without_io_binding(inputs):
         return outputs
 
     generate_fn = with_io_binding if args.device != "cpu" else without_io_binding
+    kv_cache_ortvalues = {}
 
     if args.profile:
-        ort_init_inputs = prepare_ort_inputs(init_inputs)
+        ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues)
         new_logname = profile_fn(args, generate_fn, ort_init_inputs, "prompt")
 
         # Turn profiling off to stop appending to log file
@@ -513,7 +481,7 @@ def without_io_binding(inputs):
 
         # Re-initialize model for new log file instead of appending to old log file
         model = get_model(args)
-        ort_iter_inputs = prepare_ort_inputs(iter_inputs)
+        ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues)
         new_logname = profile_fn(args, generate_fn, ort_iter_inputs, "token")
 
         # Turn profiling off to stop appending to log
@@ -524,12 +492,12 @@ def without_io_binding(inputs):
 
     # ORT evaluations
     logger.info("\nEvaluating `model(inputs)` step to get past_key_values")
-    ort_init_inputs = prepare_ort_inputs(init_inputs)
+    ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues)
     time_fn(args, generate_fn, ort_init_inputs)
     measure_fn(args, generate_fn, ort_init_inputs)
 
     logger.info("\nEvaluating `model(inputs)` step with past_key_values")
-    ort_iter_inputs = prepare_ort_inputs(iter_inputs)
+    ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues)
     time_fn(args, generate_fn, ort_iter_inputs)
     measure_fn(args, generate_fn, ort_iter_inputs)
diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
index 69603fd3ed488..3f05be53c6729 100644
--- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
@@ -716,6 +716,7 @@ def main():
         run_torchscript_separate_export(args, l_config, llama)
     else:
         run_torchscript_merged_export(args, l_config, llama)
+    del llama  # Delete LLaMA model from memory since it will be loaded again during parity check
 
     # Set model paths to store FP32 optimized model
     decoder_model_fp32_opt_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32_opt.onnx")
@@ -811,7 +812,6 @@ def main():
         logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!")
         remove_existing_model(fp_path)
 
-    del llama  # Delete LLaMA model from memory since it will be loaded again during parity check
     logger.info("Verifying parity on all ONNX models created")
 
     # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
index 2652e9f0ca64e..f7a1b05249abf 100644
--- a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
+++ b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
@@ -4,7 +4,7 @@
 import torch
 from transformers import LlamaConfig
 
-from onnxruntime import OrtValue
+from onnxruntime import InferenceSession, OrtValue
 
 
 # Get position_ids from attention_mask
@@ -12,22 +12,36 @@ def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool):
     position_ids = attention_mask.long().cumsum(-1) - 1
     position_ids.masked_fill_(attention_mask == 0, 1)
     if use_past_kv:
+        # Shape: (batch_size, 1)
         position_ids = position_ids[:, -1].unsqueeze(-1)
+
+    # Shape: (batch_size, sequence_length)
     return position_ids
 
 
 # Inputs for first pass to get initial past_key_values
+# input_ids: (batch_size, sequence_length)
+# attention_mask: (batch_size, sequence_length)
+# position_ids: (batch_size, sequence_length)
 def get_sample_inputs(
-    config: LlamaConfig, device: torch.device, batch_size: int, seq_len: int, return_dict: bool = False
+    config: LlamaConfig,
+    device: torch.device,
+    batch_size: int,
+    seq_len: int,
+    engine: str = "pt",
+    return_dict: bool = False,
 ):
-    input_ids = torch.randint(
-        low=0, high=config.vocab_size, size=(batch_size, seq_len), device=device, dtype=torch.int64
-    )
-    attention_mask = torch.ones(batch_size, seq_len, device=device, dtype=torch.int64)
-    # position_ids is of shape (batch_size, seq_len)
+    input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
+    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64)
     position_ids = get_position_ids(attention_mask, use_past_kv=False)
 
+    # Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
+    input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
+    attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
+    position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
+
     if not return_dict:
+        # For export
        return (input_ids, attention_mask, position_ids)
 
     inputs = {
@@ -39,85 +53,192 @@ def get_sample_inputs(
 
 
 # Inputs for subsequent passes with past_key_values
+# input_ids: (batch_size, 1)
+# attention_mask: (batch_size, past_sequence_length + 1)
+# position_ids: (batch_size, 1)
+# past_key: (batch_size, num_heads, past_sequence_length, head_size)
+# past_value: (batch_size, num_heads, past_sequence_length, head_size)
 def get_sample_with_past_kv_inputs(
     config: LlamaConfig,
     device: torch.device,
     batch_size: int,
     past_seq_len: int,
     use_fp16: bool = False,
+    engine: str = "pt",
     return_dict: bool = False,
 ):
-    input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), device=device, dtype=torch.int64)
-    attention_mask = torch.ones(batch_size, past_seq_len + 1, device=device, dtype=torch.int64)
+    input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), dtype=torch.int64)
+    attention_mask = torch.ones(batch_size, past_seq_len + 1, dtype=torch.int64)
     # position_ids is of shape (batch_size, 1)
     position_ids = get_position_ids(attention_mask, use_past_kv=True)
-    past_kv = get_sample_past_kv_inputs(config, device, batch_size, past_seq_len, use_fp16)
+    past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16)
+
+    # Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
+    input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
+    attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
+    position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
+    past_kv = (
+        flatten_past_kv_inputs(past_kv)
+        if engine == "ort"
+        else list(map(lambda kv: (kv[0].to(device), kv[1].to(device)), past_kv))
+    )
 
     if not return_dict:
+        # For export
+        assert isinstance(past_kv, list)
         return (input_ids, attention_mask, position_ids, past_kv)
 
     inputs = {
         "input_ids": input_ids,
         "attention_mask": attention_mask,
         "position_ids": position_ids,
-        "past_key_values": past_kv,
     }
+    if engine == "ort":
+        assert isinstance(past_kv, dict)
+        inputs.update(past_kv)
+    else:
+        assert isinstance(past_kv, list)
+        inputs["past_key_values"] = past_kv
+
     return inputs
 
 
 # Inputs for all passes with past_key_values
+# input_ids: (batch_size, sequence_length)
+# attention_mask: (batch_size, past_sequence_length + sequence_length)
+# position_ids: (batch_size, sequence_length)
+# past_key: (batch_size, num_heads, kv_sequence_length, head_size)
+#   For models with GQA, kv_sequence_length = max_sequence_length
+#   For models without GQA, kv_sequence_length = past_sequence_length
+# past_value: (batch_size, num_heads, kv_sequence_length, head_size)
+#   For models with GQA, kv_sequence_length = max_sequence_length
+#   For models without GQA, kv_sequence_length = past_sequence_length
 def get_merged_sample_with_past_kv_inputs(
     config: LlamaConfig,
     device: torch.device,
     batch_size: int,
     seq_len: int,
     past_seq_len: int,
+    max_seq_len: int,
     use_fp16: bool = False,
+    engine: str = "pt",
     return_dict: bool = False,
 ):
-    input_ids = torch.randint(
-        low=0, high=config.vocab_size, size=(batch_size, seq_len), device=device, dtype=torch.int64
-    )
-    attention_mask = torch.ones(batch_size, past_seq_len + seq_len, device=device, dtype=torch.int64)
+    input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
+    attention_mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64)
     # position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation
     position_ids = get_position_ids(attention_mask, use_past_kv=(past_seq_len != 0))
-    past_kv = get_sample_past_kv_inputs(config, device, batch_size, past_seq_len, use_fp16)
+    past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16)
+
+    # Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
+    input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
+    attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
+    position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
+    past_kv = (
+        flatten_past_kv_inputs(past_kv)
+        if engine == "ort"
+        else list(map(lambda kv: (kv[0].to(device), kv[1].to(device)), past_kv))
+    )
 
     if not return_dict:
+        # For export
+        assert isinstance(past_kv, list)
         return (input_ids, attention_mask, position_ids, past_kv)
 
     inputs = {
         "input_ids": input_ids,
         "attention_mask": attention_mask,
         "position_ids": position_ids,
-        "past_key_values": past_kv,
     }
+    if engine == "ort":
+        assert isinstance(past_kv, dict)
+        inputs.update(past_kv)
+
+        if use_fp16:  # If model has GQA
+            del inputs["attention_mask"]
+            inputs["past_sequence_length"] = np.array([past_seq_len], dtype=np.int64)
+            inputs = enable_past_present_share_buffer(inputs, past_seq_len, max_seq_len)
+
+    else:
+        assert isinstance(past_kv, list)
+        inputs["past_key_values"] = past_kv
+
     return inputs
 
 
-# Create past_key_values
-def get_sample_past_kv_inputs(
-    config: LlamaConfig, device: torch.device, batch_size: int, past_seq_len: int, use_fp16: bool
+# Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx
+def get_msft_sample_inputs(
+    config: LlamaConfig,
+    batch_size: int,
+    past_seq_len: int,
+    seq_len: int,
+    max_seq_len: int,
+    use_fp16: bool,
+    split_kv: bool,
 ):
+    np_dtype = np.float16 if use_fp16 else np.float32
+    head_size = config.hidden_size // config.num_attention_heads
+
+    if not split_kv:
+        ort_inputs = {
+            "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype),
+            "attn_mask": (-10000.0 * np.triu(np.ones((batch_size, max_seq_len, max_seq_len)), k=1)).astype(np_dtype),
+            "k_cache": np.random.rand(
+                batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size
+            ).astype(np_dtype),
+            "v_cache": np.random.rand(
+                batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size
+            ).astype(np_dtype),
+            "pos": np.array(past_seq_len, dtype=np.int64),
+        }
+    else:
+        ort_inputs = {
+            "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype),
+            "attn_mask": (np.triu(np.ones((batch_size, max_seq_len, max_seq_len), dtype=np.int32), k=1) - 1).astype(
+                np.int32
+            ),
+            "pos": np.array(past_seq_len, dtype=np.int64),
+        }
+        for i in range(config.num_hidden_layers):
+            ort_inputs.update(
+                {
+                    f"k_{i}_cache": np.random.rand(
+                        batch_size, config.num_attention_heads, past_seq_len, head_size
+                    ).astype(np_dtype),
+                    f"v_{i}_cache": np.random.rand(
+                        batch_size, config.num_attention_heads, past_seq_len, head_size
+                    ).astype(np_dtype),
+                }
+            )
+
+    if use_fp16:  # If model has GQA
+        del ort_inputs["attn_mask"]
+        ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len)
+
+    return ort_inputs
+
+
+# Create past_key_values
+# Each is of shape (batch_size, num_heads, past_sequence_length, head_size)
+def get_past_kv_inputs(config: LlamaConfig, batch_size: int, past_seq_len: int, use_fp16: bool):
     num_heads, head_size = config.num_key_value_heads, config.hidden_size // config.num_key_value_heads
     torch_dtype = torch.float16 if use_fp16 else torch.float32
     past_kv = [
         (
-            torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype),
-            torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype),
+            torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype),
+            torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype),
         )
         for _ in range(config.num_hidden_layers)
     ]
     return past_kv
 
 
-# Convert list of past_kv to dict of past_key and past_value
-def flatten_past_kv_inputs(past_key_values: List[Tuple[torch.Tensor, torch.Tensor]], use_fp16: bool):
+# Convert list of past_key_values to dict of past_key and past_value
+def flatten_past_kv_inputs(past_key_values: List[Tuple[torch.Tensor, torch.Tensor]]):
     past_kv = {}
-    np_dtype = np.float16 if use_fp16 else np.float32
     for i, (past_k, past_v) in enumerate(past_key_values):
-        past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy().astype(np_dtype)
-        past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy().astype(np_dtype)
+        past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy()
+        past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy()
     return past_kv
 
 
@@ -136,7 +257,7 @@ def convert_inputs_for_ort(
         if isinstance(v, np.ndarray):
             ort_inputs[k] = v
         elif k == "past_key_values":
-            ort_inputs.update(flatten_past_kv_inputs(v, use_fp16))
+            ort_inputs.update(flatten_past_kv_inputs(v))
         elif k == "attention_mask" and use_fp16 and use_buffer_share:
             # Skip because FP16 model has GroupQueryAttention, uses buffer sharing,
             # and GQA supports a causal mask by default
@@ -146,59 +267,55 @@ def convert_inputs_for_ort(
         else:
             ort_inputs[k] = v.detach().cpu().numpy()
 
-    # Enable past-present-share-buffer by using device memory directly
+    # Reshape kv caches if using past-present-share-buffer
     if use_buffer_share and device != "" and device != "cpu" and device_id > -1:
-        for k, v in ort_inputs.items():
-            new_v = v
-            # Allocate new buffers with max_sequence_length for GQA
-            if "cache" in k or "past_key_values" in k:
-                # Copy v (BxSxPxH) into new_v (BxSxMxH)
-                batch_size, num_heads, _, head_size = v.shape
-                new_v = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=v.dtype)
-                new_v[:batch_size, :num_heads, :past_seq_len, :head_size] = v
-            ort_inputs[k] = OrtValue.ortvalue_from_numpy(new_v, device_type=device, device_id=device_id)
+        ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len)
 
     return ort_inputs
 
 
-# Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx
-def get_msft_sample_inputs(
-    config: LlamaConfig, batch_size: int, past_seq_len: int, seq_len: int, use_fp16: bool, split_kv: bool
-):
-    np_dtype = np.float16 if use_fp16 else np.float32
-    head_size = config.hidden_size // config.num_attention_heads
-    max_seq_len = 2048
+def enable_past_present_share_buffer(ort_inputs: dict, past_seq_len: int, max_seq_len: int):
+    for k, v in ort_inputs.items():
+        # Allocate new buffers with max_sequence_length for GQA
+        if "cache" in k or "past_key_values" in k:
+            # Copy v (BxSxPxH) into new_v (BxSxMxH)
+            batch_size, num_heads, _, head_size = v.shape
+            new_v = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=v.dtype)
+            new_v[:batch_size, :num_heads, :past_seq_len, :head_size] = v
+            ort_inputs[k] = new_v
+    return ort_inputs
 
-    if not split_kv:
-        ort_inputs = {
-            "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype),
-            "attn_mask": (-10000.0 * np.triu(np.ones((batch_size, max_seq_len, max_seq_len)), k=1)).astype(np_dtype),
-            "k_cache": np.random.rand(
-                batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size
-            ).astype(np_dtype),
-            "v_cache": np.random.rand(
-                batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size
-            ).astype(np_dtype),
-            "pos": np.array(past_seq_len, dtype=np.int64),
-        }
-    else:
-        ort_inputs = {
-            "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype),
-            "attn_mask": (np.triu(np.ones((batch_size, max_seq_len, max_seq_len), dtype=np.int32), k=1) - 1).astype(
-                np.int32
-            ),
-            "pos": np.array(past_seq_len, dtype=np.int64),
-        }
-        for i in range(config.num_hidden_layers):
-            ort_inputs.update(
-                {
-                    f"k_{i}_cache": np.random.rand(
-                        batch_size, config.num_attention_heads, past_seq_len, head_size
-                    ).astype(np_dtype),
-                    f"v_{i}_cache": np.random.rand(
-                        batch_size, config.num_attention_heads, past_seq_len, head_size
-                    ).astype(np_dtype),
-                }
-            )
-    return ort_inputs
+
+# Add IO bindings for execution providers
+def add_io_bindings(model: InferenceSession, ort_inputs: dict, device: str, device_id: int, kv_cache_ortvalues: dict):
+    use_fp16 = False
+    io_binding = model.io_binding()
+
+    for k, v in ort_inputs.items():
+        # Detect if model is in FP16
+        if v.dtype == np.float16:
+            use_fp16 = True
+
+        # Bind OrtValue inputs to device
+        if use_fp16 and ("cache" in k or "past_key_values" in k):
+            if k not in kv_cache_ortvalues:
+                v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id)
+                io_binding.bind_ortvalue_input(k, v_device)
+                kv_cache_ortvalues[k] = v_device
+            else:
+                kv_cache_ortvalues[k].update_inplace(v)
+                io_binding.bind_ortvalue_input(k, kv_cache_ortvalues[k])
+        else:
+            v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id)
+            io_binding.bind_ortvalue_input(k, v_device)
+
+    for output in model.get_outputs():
+        name = output.name
+        if use_fp16 and ("out" in name or "present" in name):
+            # Bind present KV cache outputs to past KV cache inputs in order to buffer share
+            input_name = name.replace("out", "cache").replace("present", "past_key_values")
+            io_binding.bind_ortvalue_output(name, kv_cache_ortvalues[input_name])
+        else:
+            io_binding.bind_output(name, device_type=device, device_id=device_id)
+
+    return io_binding, kv_cache_ortvalues
diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py
index 4353d0606803d..c1c5d3c412f2a 100644
--- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py
+++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py
@@ -8,6 +8,7 @@
 import torch
 from benchmark_helper import setup_logger
 from llama_inputs import (
+    add_io_bindings,
     convert_inputs_for_ort,
     get_merged_sample_with_past_kv_inputs,
     get_sample_inputs,
@@ -22,22 +23,24 @@ def get_sequence_lengths(args: argparse.Namespace):
     past_sequence_length, curr_sequence_length = (8, 1) if args.use_past_kv else (0, 8)
-    max_sequence_length = 2048
+    temp_name = args.model_name.lower().replace("-", "").replace("_", "")
+    max_sequence_length = 16384 if "codellama" in temp_name else 4096 if "llama2" in temp_name else 2048
     return past_sequence_length, curr_sequence_length, max_sequence_length
 
 
 def get_inputs(args: argparse.Namespace, config: LlamaConfig):
     # Dummy values for parity
     batch_size = 2
-    past_sequence_length, sequence_length, _ = get_sequence_lengths(args)
+    past_sequence_length, sequence_length, max_sequence_length = get_sequence_lengths(args)
 
     if args.merged:
         inputs = get_merged_sample_with_past_kv_inputs(
             config,
             args.device,
             batch_size,
-            sequence_length,
-            past_sequence_length,
+            seq_len=sequence_length,
+            past_seq_len=past_sequence_length,
+            max_seq_len=max_sequence_length,
             use_fp16=args.use_fp16,
             return_dict=True,
         )
@@ -51,31 +54,7 @@ def get_inputs(args: argparse.Namespace, config: LlamaConfig):
     return inputs
 
 
-def add_io_bindings(args: argparse.Namespace, model: ort.InferenceSession, inputs: dict):
-    # Add IO bindings for non-CPU execution providers
-    io_binding = model.io_binding()
-
-    for k, v in inputs.items():
-        if args.use_fp16:
-            # Bind all OrtValue inputs to device
-            io_binding.bind_ortvalue_input(k, v)
-        else:
-            io_binding.bind_cpu_input(k, v)
-
-    for output in model.get_outputs():
-        name = output.name
-        if args.use_fp16 and ("out" in name or "present" in name):
-            # Bind present KV cache outputs to OrtValue with buffer sharing
-            io_binding.bind_ortvalue_output(
-                name, inputs[name.replace("out", "cache").replace("present", "past_key_values")]
-            )
-        else:
-            io_binding.bind_output(name, device_type=args.execution_provider, device_id=int(args.device_id))
-
-    return io_binding
-
-
-def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: LlamaForCausalLM):
+def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: LlamaForCausalLM, kv_cache_ortvalues: dict):
     inputs = get_inputs(args, config)
 
     # Run inference with PyTorch
@@ -111,7 +90,9 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama
 
     # Add IO bindings for non-CPU execution providers
    if args.execution_provider != "cpu":
-        io_binding = add_io_bindings(args, ort_model, inputs)
+        io_binding, kv_cache_ortvalues = add_io_bindings(
+            ort_model, inputs, args.execution_provider, int(args.device_id), kv_cache_ortvalues
+        )
 
         io_binding.synchronize_inputs()
         start_time = time.time()
@@ -131,17 +112,12 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama
     logger.info(f"ONNX Runtime took {end_time - start_time} s")
 
     # Compare PyTorch and ONNX Runtime accuracy
-    tol = (
-        2e1
-        if "int4" in args.onnx_model_path or "int8" in args.onnx_model_path
-        else 1e-3
-        if args.precision == "fp32"
-        else 5e-1
-    )
+    tol = 2e1 if "int4" in args.onnx_model_path or "int8" in args.onnx_model_path else 5e-1
     parity = np.allclose(pt_outputs, ort_outputs, rtol=tol, atol=tol)
     logger.warning(f"Are PyTorch and ONNX Runtime results close? {parity}")
     if not parity:
         logger.warning(f"Max diff: {np.max(pt_outputs - ort_outputs)}")
+    return kv_cache_ortvalues
 
 
 def get_args(argv: List[str]):
@@ -250,16 +226,17 @@ def main(argv: List[str] = []):  # noqa: B006
         use_cache=True,
     ).to(args.device)
 
+    kv_cache_ortvalues = {}
     if not args.merged:
-        verify_parity(args, config, llama)
+        verify_parity(args, config, llama, kv_cache_ortvalues)
     else:
         # Verify prompt generation in merged model (decoder_model.onnx)
         args.use_past_kv = False
-        verify_parity(args, config, llama)
+        kv_cache_ortvalues = verify_parity(args, config, llama, kv_cache_ortvalues)
 
         # Verify token generation in merged model (decoder_with_past_model.onnx)
         args.use_past_kv = True
-        verify_parity(args, config, llama)
+        verify_parity(args, config, llama, kv_cache_ortvalues)
 
 
 if __name__ == "__main__":
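
For reference, a minimal usage sketch of how the reworked helpers fit together after this patch. It is not part of the patch itself: the model ID, ONNX path, and device are placeholders, while the helper names and signatures come from the new llama_inputs.py above.

import torch
from transformers import LlamaConfig

import onnxruntime as ort
from llama_inputs import add_io_bindings, get_merged_sample_with_past_kv_inputs

config = LlamaConfig.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder model ID
model = ort.InferenceSession(
    "llama2-7b_decoder_merged_model_fp16.onnx",  # placeholder path to a merged FP16 (GQA) export
    providers=["CUDAExecutionProvider"],
)

# Device-side KV cache OrtValues, created once and reused across prompt/token steps
kv_cache_ortvalues = {}

# Prompt step: engine="ort" returns NumPy inputs; with use_fp16 (GQA) the KV caches are
# padded to max_seq_len so past/present buffers can be shared
inputs = get_merged_sample_with_past_kv_inputs(
    config,
    torch.device("cuda", 0),
    batch_size=1,
    seq_len=8,
    past_seq_len=0,
    max_seq_len=4096,
    use_fp16=True,
    engine="ort",
    return_dict=True,
)
io_binding, kv_cache_ortvalues = add_io_bindings(model, inputs, "cuda", 0, kv_cache_ortvalues)
model.run_with_iobinding(io_binding)

Passing the same kv_cache_ortvalues dict into the next call is what lets add_io_bindings rebind the present-KV outputs onto the existing device buffers instead of reallocating them on every step.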