Add support for generating ONNX Runtime GenAI files only in model builder (#186)

### Description

This PR adds support for generating the config and tokenizer files that
are needed for ONNX Runtime GenAI.

### Motivation and Context

This is for when users already have their optimized and/or quantized
ONNX models. By generating the config and tokenizer files only, users
can run their own models with ONNX Runtime GenAI.
kunal-vaishnavi authored Mar 12, 2024
1 parent a73e40c commit 5a2ca38
Showing 2 changed files with 38 additions and 19 deletions.
13 changes: 13 additions & 0 deletions src/python/py/models/README.md
@@ -11,6 +11,7 @@ This folder contains the model builder tool, which greatly accelerates creating
- [Customized or Finetuned PyTorch Model](#customized-or-finetuned-pytorch-model)
- [GGUF Model](#gguf-model)
- [Extra Options](#extra-options)
- [Config Only](#config-only)
- [Unit Testing Models](#unit-testing-models)
- [Option 1: Use the model builder tool directly](#option-1-use-the-model-builder-tool-directly)
- [Option 2: Edit the config.json file](#option-2-edit-the-configjson-file-on-disk-and-then-run-the-model-builder-tool)
@@ -86,6 +87,18 @@ python3 builder.py -m model_name -o path_to_output_folder -p precision -e execut
```
To see all available options through `--extra_options`, please use the `help` commands in the `Full Usage` section above.

### Config Only
This scenario is for when you already have an optimized and/or quantized ONNX model and need to create the config and tokenizer files to run it with ONNX Runtime GenAI.
```
# From wheel:
python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_folder -p precision -e execution_provider -c cache_dir_for_hf_files --extra_options config_only=true
# From source:
python3 builder.py -m model_name -o path_to_output_folder -p precision -e execution_provider -c cache_dir_for_hf_files --extra_options config_only=true
```

Afterwards, please open the `genai_config.json` file in the output folder and modify the fields as needed for your model. You should store your ONNX model in the output folder as well.
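For instance, a minimal sketch of pointing the config at your own model file (the field nesting and file names shown here are assumptions, so check them against your generated `genai_config.json`):
```
# Illustrative sketch: make the generated genai_config.json reference the ONNX model
# that you copied into the output folder. The "model" -> "decoder" -> "filename"
# nesting and the file names below are assumptions; adjust them to your setup.
import json
import os

output_dir = "path_to_output_folder"
config_path = os.path.join(output_dir, "genai_config.json")

with open(config_path) as f:
    genai_config = json.load(f)

# Point the decoder at your optimized and/or quantized ONNX model.
genai_config["model"]["decoder"]["filename"] = "my_model.onnx"

with open(config_path, "w") as f:
    json.dump(genai_config, f, indent=4)
```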

### Unit Testing Models
This scenario is for when your PyTorch model is already downloaded locally (either in the default Hugging Face cache directory or in a local folder on disk). If it is not, here is an example of how you can download it.
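A minimal sketch of one way to do this with the `transformers` library (an assumption; any method that places the files locally works, and the model name and cache directory below are placeholders):
```
# Illustrative sketch: download a Hugging Face model into a local cache directory.
# "model_name" and "cache_dir_for_hf_files" are placeholders; substitute your own values.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "model_name"             # e.g. a Hugging Face repo id
cache_dir = "cache_dir_for_hf_files"  # same cache directory passed to the builder via -c

model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
```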

44 changes: 25 additions & 19 deletions src/python/py/models/builder.py
@@ -25,7 +25,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
         self.window_size = config.sliding_window if hasattr(config, "sliding_window") else -1 # default is -1 in GroupQueryAttention kernel
         self.intermediate_size = config.intermediate_size
         self.hidden_size = config.hidden_size
-        self.num_kv_heads = config.num_key_value_heads
+        self.num_kv_heads = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads
         self.num_attn_heads = config.num_attention_heads
         self.head_size = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
         self.num_layers = int(extra_options["num_hidden_layers"]) if "num_hidden_layers" in extra_options else config.num_hidden_layers
@@ -102,12 +102,13 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
 
         # RotaryEmbedding-specific variables
         partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+        rope_theta = config.rope_theta if hasattr(config, "rope_theta") else 10000
         self.rotemb_attrs = {
             "create_rotary_embedding_caches": True, # Create cos/sin caches for rotary embeddings
             "partial_rotary_factor": partial_rotary_factor, # Factor for partial rotary embeddings
             "num_heads": 0, # For partial rotary embeddings (RotaryEmbedding kernel expects a default value of 0)
             "rotary_embedding_dim": 0, # For partial rotary embeddings (RotaryEmbedding kernel expects a default value of 0)
-            "theta": config.rope_theta, # Base value if calculating cos/sin caches from scratch
+            "theta": rope_theta, # Base value if calculating cos/sin caches from scratch
         }
 
         # Attention-specific variables (MHA, GQA, GQA + Rot.Emb., etc.)
@@ -139,8 +140,8 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
                 "context_length": self.context_length,
                 "decoder": {
                     "session_options" : {
-                        "provider_options" : [
-                        ]
+                        "log_id": "onnxruntime-genai",
+                        "provider_options" : []
                     },
                     "filename": self.filename,
                     "head_size": self.head_size,
@@ -1400,23 +1401,26 @@ def create_model(model_name, input_path, output_dir, precision, execution_provid
     # Set input/output precision of ONNX model
     io_dtype = TensorProto.FLOAT if precision in {"int8", "fp32"} or (precision == "int4" and execution_provider == "cpu") else TensorProto.FLOAT16
 
-    # List architecture options in alphabetical order
-    if config.architectures[0] == "GemmaForCausalLM":
-        onnx_model = GemmaModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
-    elif config.architectures[0] == "LlamaForCausalLM":
-        onnx_model = LlamaModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
-    elif config.architectures[0] == "MistralForCausalLM":
-        onnx_model = MistralModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
-    elif config.architectures[0] == "PhiForCausalLM":
-        onnx_model = PhiModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
-    else:
-        raise NotImplementedError(f"The {hf_name} model is not currently supported.")
+    if "config_only" not in extra_options:
+        # List architecture options in alphabetical order
+        if config.architectures[0] == "GemmaForCausalLM":
+            onnx_model = GemmaModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
+        elif config.architectures[0] == "LlamaForCausalLM":
+            onnx_model = LlamaModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
+        elif config.architectures[0] == "MistralForCausalLM":
+            onnx_model = MistralModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
+        elif config.architectures[0] == "PhiForCausalLM":
+            onnx_model = PhiModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
+        else:
+            raise NotImplementedError(f"The {hf_name} model is not currently supported.")
 
-    # Make ONNX model
-    onnx_model.make_model(input_path)
+        # Make ONNX model
+        onnx_model.make_model(input_path)
 
-    # Save ONNX model
-    onnx_model.save_model(output_dir)
+        # Save ONNX model
+        onnx_model.save_model(output_dir)
+    else:
+        onnx_model = Model(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
 
     # Make GenAI config
     onnx_model.make_genai_config(hf_name, extra_kwargs, output_dir)
@@ -1497,6 +1501,8 @@ def get_args():
         filename = Filename for ONNX model (default is 'model.onnx').
             For models with multiple components, each component is exported to its own ONNX model.
             The filename for each component will be '<filename>_<component-name>.onnx' (ex: '<filename>_encoder.onnx', '<filename>_decoder.onnx').
+        config_only = Generate config and pre/post processing files only.
+            Use this option when you already have your optimized and/or quantized ONNX model.
         """),
     )
 
