Mistral Optimization & Benchmarking Support (#18225)
### Description
As a prerequisite for this model to run correctly, two PRs need to be merged:

- GQA Sliding Window Attention:
https://github.com/microsoft/onnxruntime/tree/aciddelgado/gqa_local
- MHA Fusion:
https://github.com/frankdongms/onnxruntime/tree/frdong/llama_70b

This PR adds optimization, quantization, and benchmarking support for
Mistral. The included README describes the steps to export, optimize, and
benchmark Mistral models, but those steps will not work correctly until the
two branches above are merged.

---------

Co-authored-by: Peter McAughan <[email protected]>
Co-authored-by: Abhishek Jindal <[email protected]>
Co-authored-by: kunal-vaishnavi <[email protected]>
4 people authored Dec 5, 2023
1 parent c9e558c commit 871c529
Showing 4 changed files with 111 additions and 7 deletions.
4 changes: 3 additions & 1 deletion onnxruntime/python/tools/transformers/convert_generation.py
@@ -1272,7 +1272,9 @@ def find_past_seq_len_usage(subg: GraphProto):
    return tensor_names_to_rename, nodes_to_remove


def replace_mha_with_gqa(model: OnnxModel, attn_mask: str, kv_num_heads: int = 0, world_size: int = 1):
def replace_mha_with_gqa(
    model: OnnxModel, attn_mask: str, kv_num_heads: int = 0, world_size: int = 1, window_size: int = 0
):
    # Insert attention_mask subgraph to calculate shared inputs for all GroupQueryAttention nodes
    #
    #                 attention_mask
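
The hunk above adds a `window_size` parameter to `replace_mha_with_gqa`, but the call site that supplies it is outside the visible context. A minimal, illustrative sketch of how a caller could forward Mistral's sliding-window length (the helper name is hypothetical, and the imports assume the working directory is `onnxruntime/python/tools/transformers` — none of this is part of the diff):

```
# Illustrative sketch, not part of this diff: forward Mistral's sliding-window length
# to the new window_size parameter when swapping MultiHeadAttention for GroupQueryAttention.
from transformers import AutoConfig

from convert_generation import replace_mha_with_gqa  # assumes CWD is the transformers tools dir
from onnx_model import OnnxModel


def mha_to_gqa_with_sliding_window(model: OnnxModel, config: AutoConfig, world_size: int = 1) -> OnnxModel:
    # window_size=0 (the default) keeps global attention; Mistral-7B-v0.1 sets sliding_window=4096.
    window_size = getattr(config, "sliding_window", 0)
    model = replace_mha_with_gqa(model, "attention_mask", config.num_key_value_heads, world_size, window_size)
    model.prune_graph()
    return model
```
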
65 changes: 65 additions & 0 deletions onnxruntime/python/tools/transformers/models/llama/README.md
@@ -1,3 +1,13 @@
# Contents
- [LLaMA-2](#llama-2)
- [Exporting LLaMA-2](#exporting-llama-2)
- [Benchmarking LLaMA-2](#benchmark-llama-2)
- [Mistral](#mistral)
- [Exporting Mistral](#exporting-mistral)
- [Optimizing and Quantizing Mistral](#optimizing-and-quantizing-mistral)
- [Benchmarking Mistral](#benchmark-mistral)


# LLaMA-2

## Prerequisites
@@ -372,3 +382,58 @@ python3 -m models.llama.benchmark_all \
--num-runs 1000 \
--timeout 60 # number of minutes before moving to the next benchmark
```

# Mistral

## Introduction

The LLaMA-2 tools described above also support optimizing and quantizing Mistral to run in ORT.

## Exporting Mistral

There is currently one supported way to export Mistral to ONNX format:

### [Hugging Face Optimum](https://github.com/huggingface/optimum)


The following command will export Mistral in full precision:
```
python -m optimum.exporters.onnx -m mistralai/Mistral-7B-v0.1 --library-name transformers /path/to/model/directory
```

## Optimizing and Quantizing Mistral

To quantize Mistral to FP16 and apply fusion optimizations, you can run the following command:
```
python -m models.llama.convert_to_onnx -i /path/to/model/directory -o /path/to/optimized_model/directory -p fp16 --optimize_optimum -m mistralai/Mistral-7B-v0.1
```
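
Under the hood, the `--optimize_optimum` path roughly performs the following steps (a simplified sketch of `optimize_optimum()` in `convert_to_onnx.py`; the file paths are placeholders and the imports assume the script is run from `onnxruntime/python/tools/transformers`):

```
# Simplified sketch of the --optimize_optimum flow; paths are placeholders.
import onnx
from transformers import AutoConfig

from models.llama.convert_to_onnx import optimize_export, use_group_query_attention
from onnx_model import OnnxModel

config = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1")

# 1) Apply ORT transformer fusions to the Optimum-exported FP32 graph.
optimize_export(config, "export/model.onnx", "fused/model.onnx", remove_model=False)

# 2) Convert to FP16, then replace MultiHeadAttention with GroupQueryAttention,
#    passing Mistral's sliding-window length so GQA can attend locally.
opt_model = OnnxModel(onnx.load_model("fused/model.onnx", load_external_data=True))
opt_model.convert_float_to_float16(keep_io_types=False)
window_size = getattr(config, "sliding_window", 0)
opt_model = use_group_query_attention(config, opt_model, world_size=1, window_size=window_size)

# 3) Save with external data so the weights can exceed the 2 GB protobuf limit.
opt_model.save_model_to_file("optimized/model.onnx", use_external_data_format=True)
```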

## Benchmark Mistral
The benchmarking scripts in the LLaMA directory support Mistral benchmarking. To benchmark the ORT version, you can run:

```
python -m models.llama.benchmark \
-bt ort-convert-to-onnx \
-p fp16 \
-m mistralai/Mistral-7B-v0.1 \
--ort-model-path /path/to/model.onnx
```

To benchmark the Hugging Face implementation without `torch.compile`:

```
python -m models.llama.benchmark \
-bt hf-pt-eager \
-p fp16 \
-m mistralai/Mistral-7B-v0.1
```

And to benchmark the Hugging Face implementation with `torch.compile`:

```
python -m models.llama.benchmark \
-bt hf-pt-compile \
-p fp16 \
-m mistralai/Mistral-7B-v0.1
```

10 changes: 8 additions & 2 deletions onnxruntime/python/tools/transformers/models/llama/benchmark.py
@@ -79,7 +79,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
            return_dict=True,
        )

    elif args.benchmark_type == "hf-ort":
    elif args.benchmark_type in {"hf-ort"}:
        if ort_model_inputs_len == 3:  # [input_ids, attention_mask, position_ids]
            # Using split models in Optimum (e.g. created by Optimum export)
            init_inputs = get_sample_inputs(
@@ -529,7 +529,13 @@ def get_args(rank=0):
        "--benchmark-type",
        type=str,
        required=True,
        choices=["hf-pt-eager", "hf-pt-compile", "hf-ort", "ort-msft", "ort-convert-to-onnx"],
        choices=[
            "hf-pt-eager",
            "hf-pt-compile",
            "hf-ort",
            "ort-msft",
            "ort-convert-to-onnx",
        ],
    )
    parser.add_argument(
        "-m",
39 changes: 35 additions & 4 deletions onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
@@ -391,7 +391,7 @@ def run_torchscript_merged_export(


# Optimize the model as FP32
def optimize_export(config: AutoConfig, input_path: str, output_path: str):
def optimize_export(config: AutoConfig, input_path: str, output_path: str, remove_model: bool = True):
    from fusion_options import FusionOptions

    optimization_options = FusionOptions("gpt2")
@@ -407,7 +407,8 @@ def optimize_export(config: AutoConfig, input_path: str, output_path: str):
    )
    model_opt.save_model_to_file(output_path, use_external_data_format=True)
    logger.info(f"The ONNX model at {input_path} has been successfully optimized and saved at {output_path}!")
    remove_existing_model(input_path)
    if remove_model:
        remove_existing_model(input_path)


def convert_to_float16(
@@ -438,7 +439,7 @@ def convert_to_float16(
    return new_paths


def use_group_query_attention(config: AutoConfig, fp16_model_opt: OnnxModel, world_size: int = 1):
def use_group_query_attention(config: AutoConfig, fp16_model_opt: OnnxModel, world_size: int = 1, window_size: int = 0):
    # Replace MultiHeadAttention with GroupQueryAttention
    fp16_model_opt = replace_mha_with_gqa(fp16_model_opt, "attention_mask", config.num_key_value_heads, world_size)
    fp16_model_opt.prune_graph()
@@ -539,6 +540,23 @@ def remove_existing_files(output_path: str):
            logger.warning(f"Removed {filepath}")


def optimize_optimum(config: AutoConfig, args: argparse.Namespace):
    tmp_file = os.path.join(args.output, args.model_name + ".tmp.onnx")
    output_file = os.path.join(args.output, args.model_name + ".onnx")
    optimize_export(config, args.input, tmp_file, remove_model=False)
    logger.info(f"Model successfully optimized to {tmp_file}")
    opt_model = OnnxModel(onnx.load_model(tmp_file, load_external_data=True))
    if args.precision == Precision.FLOAT16:
        opt_model.convert_float_to_float16(keep_io_types=False)
        window_size = 0 if not hasattr(config, "sliding_window") else config.sliding_window
        opt_model = use_group_query_attention(config, opt_model, args.world_size, window_size)
        logger.info("Model successfully fused and quantized to FP16!")
    opt_model.save_model_to_file(output_file, use_external_data_format=True)
    logger.info(f"Output model successfully saved to {output_file}")
    logger.info(f"Removing {tmp_file}")
    remove_existing_model(tmp_file)


def get_args():
    parser = argparse.ArgumentParser()

@@ -554,7 +572,7 @@ def get_args():
        "--input",
        required=False,
        default=os.path.join("."),
        help="Directory path to PyTorch model and associated files if saved on disk",
        help="Directory path to PyTorch model and associated files if saved on disk, or ONNX model file location if optimize_optimum is passed.",
    )

    parser.add_argument(
@@ -720,6 +738,13 @@ def get_args():
        help="model cache dir to override default HF cache dir to avoid overflood the /home dir",
    )

    parser.add_argument(
        "--optimize_optimum",
        action="store_true",
        help="Avoid exporting model, only apply quantizations and optimizations to existing model exported from optimum.",
    )
    parser.set_defaults(optimize_optimum=False)

    args = parser.parse_args()
    return args

@@ -740,6 +765,7 @@ def main():

    world_size = get_size()
    rank = get_rank()
    args.world_size = world_size

    # Load model and config
    use_auth_token = args.input == os.path.join(".")
@@ -754,6 +780,11 @@ def main():

    location = args.original_model_name if use_auth_token else args.input

    if args.optimize_optimum:
        config = AutoConfig.from_pretrained(args.original_model_name)
        optimize_optimum(config, args)
        return

    # Use CUDA for LLaMA-2-70B to speed up export and CPU for other models
    l_config, llama = setup_torch_model(
        args, location, use_auth_token, device=args.device if args.model_name == "Llama-2-70b-hf" else None
