
Add Granite to model builder #1153

Open · wants to merge 13 commits into base: main
README.md (22 changes: 10 additions & 12 deletions)
@@ -1,28 +1,26 @@
# ONNX Runtime generate() API
# ONNX Runtime GenAI

## *The main branch contains new API changes, and the examples in it reflect those changes. For example scripts compatible with the current release (0.5.2), [see the release branch](https://github.com/microsoft/onnxruntime-genai/tree/rel-0.5.2).*


[![Latest version](https://img.shields.io/nuget/vpre/Microsoft.ML.OnnxRuntimeGenAI.Managed?label=latest)](https://www.nuget.org/packages/Microsoft.ML.OnnxRuntimeGenAI.Managed/absoluteLatest)

Run Llama, Phi, Gemma, Mistral with ONNX Runtime.
Run generative AI models with ONNX Runtime.

This API gives you an easy, flexible and performant way of running LLMs on device.

It implements the generative AI loop for ONNX models, including pre- and post-processing, inference with ONNX Runtime, logits processing, search and sampling, and KV cache management.

You can call a high level `generate()` method to generate all of the output at once, or stream the output one token at a time.

See documentation at https://onnxruntime.ai/docs/genai.
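For example, streaming tokens from the Python package looks roughly like the sketch below. This is a hedged example: the package name `onnxruntime_genai` is the published one, but method names such as `append_tokens` and `create_stream` follow recent release examples and may differ between versions.

```python
import onnxruntime_genai as og

# Load a model folder produced by the ONNX Runtime GenAI model builder.
model = og.Model("path/to/onnx/model")
tokenizer = og.Tokenizer(model)
stream = tokenizer.create_stream()

params = og.GeneratorParams(model)
params.set_search_options(max_length=256)

generator = og.Generator(model, params)
generator.append_tokens(tokenizer.encode("What is ONNX Runtime?"))

# Stream the output one token at a time instead of waiting for the full sequence.
while not generator.is_done():
    generator.generate_next_token()
    print(stream.decode(generator.get_next_tokens()[0]), end="", flush=True)
```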

|Support matrix|Supported now|Under development|On the roadmap|
|-|-|-|-|
|Model architectures| Gemma <br/> Llama * <br/> Mistral + <br/>Phi (language + vision)<br/>Qwen <br/>Nemotron <br/>|Whisper|Stable diffusion|
|API| Python <br/>C# <br/>C/C++ <br/> Java ^ |Objective-C||
|Platform| Linux <br/> Windows <br/>Mac ^ <br/>Android ^ ||iOS |||
|Architecture|x86 <br/> x64 <br/> Arm64 ~ ||||
|Hardware Acceleration|CUDA<br/>DirectML<br/>|QNN <br/> OpenVINO <br/> ROCm ||
|Features|| Interactive decoding <br/> Customization (fine-tuning)| Speculative decoding |
| Support matrix | Supported now | Under development | On the roadmap |
| -------------- | ------------- | ----------------- | -------------- |
| Model architectures | Gemma <br/> Llama * <br/> Mistral + <br/> Phi (language + vision) <br/> Qwen <br/> Nemotron <br/> Granite <br/> | Whisper | Stable diffusion |
| API | Python <br/> C# <br/> C/C++ <br/> Java ^ | Objective-C | |
| Platform | Linux <br/> Windows <br/> Mac ^ <br/> Android ^ | | iOS |
| Architecture | x86 <br/> x64 <br/> Arm64 ~ | | |
| Hardware Acceleration | CUDA <br/> DirectML <br/> | QNN <br/> OpenVINO <br/> ROCm | |
| Features | | Interactive decoding <br/> Customization (fine-tuning) | Speculative decoding |

\* The Llama model architecture supports similar model families such as CodeLlama, Vicuna, Yi, and more.

src/python/py/models/README.md (1 change: 1 addition & 0 deletions)
@@ -33,6 +33,7 @@ The tool currently supports the following model architectures.

- ChatGLM
- Gemma
- Granite
- LLaMA
- Mistral
- Nemotron
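With this change, a Granite checkpoint can be exported the same way as the other listed architectures, for example `python3 builder.py -m ibm-granite/granite-3.0-2b-instruct -o ./granite-onnx -p int4 -e cpu`. The flag names here are assumptions based on the builder's usual usage; check this README or `builder.py --help` for the exact options.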
src/python/py/models/builder.py (106 changes: 79 additions & 27 deletions)
@@ -1714,21 +1714,45 @@
def make_mlp_proj(self, layer_id, mlp, root_input):
# Make nodes for the MLP subgraph
#
# root_input
# / \
# root_input
# / \
# / \
# UpProjMatMul GateProjMatMul
# \ |
# \ ActFunc
# \ /
# Mul
# |
# DownProjMatMul

# Make MatMul nodes
gate_basename = f"/model/layers.{layer_id}/mlp/gate_proj/MatMul"
gate_name = self.make_matmul(mlp.gate_proj, gate_basename, root_input)
up_basename = f"/model/layers.{layer_id}/mlp/up_proj/MatMul"
up_name = self.make_matmul(mlp.up_proj, up_basename, root_input)
# | |
# UpProjAdd GateProjAdd
# \ |
# \ ActFunc
# \ /
# \ /
# \ /
# Mul
# |
# DownProjMatMul
# |
# DownProjAdd

# Check if Add nodes need to be made (if bias exists)
gate_bias_exists = mlp.gate_proj.bias is not None and torch.count_nonzero(mlp.gate_proj.bias) > 0
up_bias_exists = mlp.up_proj.bias is not None and torch.count_nonzero(mlp.up_proj.bias) > 0
down_bias_exists = mlp.down_proj.bias is not None and torch.count_nonzero(mlp.down_proj.bias) > 0
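# Note: a bias tensor that exists but is all zeros is treated the same as a missing bias,
# so no Add node is emitted for it.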

# Make Gate proj nodes
gate_matmul_basename = f"/model/layers.{layer_id}/mlp/gate_proj/MatMul"
gate_matmul_name = self.make_matmul(mlp.gate_proj, gate_matmul_basename, root_input)
gate_name = gate_matmul_name
if gate_bias_exists:
gate_add_name = f"/model/layers.{layer_id}/mlp/gate_proj/Add"
self.make_add_bias(mlp.gate_proj.bias.detach().numpy(), gate_add_name, root_input=f"{gate_name}/output_0")
gate_name = gate_add_name

# Make Up proj nodes
up_matmul_basename = f"/model/layers.{layer_id}/mlp/up_proj/MatMul"
up_matmul_name = self.make_matmul(mlp.up_proj, up_matmul_basename, root_input)
up_name = up_matmul_name
if up_bias_exists:
up_add_name = f"/model/layers.{layer_id}/mlp/up_proj/Add"
self.make_add_bias(mlp.up_proj.bias.detach().numpy(), up_add_name, root_input=f"{up_name}/output_0")
up_name = up_add_name

# Make activation node(s)
act_fn_name = self.make_activation(layer_id, root_input=f"{gate_name}/output_0")
@@ -1740,8 +1764,13 @@

# Make output MatMul node
down_proj = getattr(mlp, "down_proj", None) or getattr(mlp, "dense_4h_to_h", None)
down_basename = f"/model/layers.{layer_id}/mlp/down_proj/MatMul"
down_name = self.make_matmul(down_proj, down_basename, f"{mul_name}/output_0")
down_matmul_basename = f"/model/layers.{layer_id}/mlp/down_proj/MatMul"
down_matmul_name = self.make_matmul(down_proj, down_matmul_basename, f"{mul_name}/output_0")
down_name = down_matmul_name
if down_bias_exists:
down_add_name = f"/model/layers.{layer_id}/mlp/down_proj/Add"
self.make_add_bias(mlp.down_proj.bias.detach().numpy(), down_add_name, root_input=f"{down_name}/output_0")
down_name = down_add_name

# Assign output 0 of previous MatMul as skip input to next SkipLayerNorm
self.layernorm_attrs["skip_input"] = f"{down_name}/output_0"
@@ -1761,23 +1790,33 @@
# |
# FC2_Add

# Check if Add nodes need to be made (if bias exists)
fc1_bias_exists = mlp.fc1.bias is not None and torch.count_nonzero(mlp.fc1.bias) > 0
fc2_bias_exists = mlp.fc2.bias is not None and torch.count_nonzero(mlp.fc2.bias) > 0

# Make first layer of fully connected nodes (FC1)
fc1_matmul_basename = f"/model/layers.{layer_id}/mlp/fc1/MatMul"
fc1_matmul_name = self.make_matmul(mlp.fc1, fc1_matmul_basename, root_input)
fc1_add_name = f"/model/layers.{layer_id}/mlp/fc1/Add"
self.make_add_bias(mlp.fc1.bias.detach().numpy(), fc1_add_name, root_input=f"{fc1_matmul_name}/output_0")
fc1_name = fc1_matmul_name
if fc1_bias_exists:
fc1_add_name = f"/model/layers.{layer_id}/mlp/fc1/Add"
self.make_add_bias(mlp.fc1.bias.detach().numpy(), fc1_add_name, root_input=f"{fc1_name}/output_0")
fc1_name = fc1_add_name

# Make activation function
act_fn_name = self.make_activation(layer_id, root_input=f"{fc1_add_name}/output_0")
act_fn_name = self.make_activation(layer_id, root_input=f"{fc1_name}/output_0")

# Make second layer of fully connected nodes (FC2)
fc2_matmul_basename = f"/model/layers.{layer_id}/mlp/fc2/MatMul"
fc2_matmul_name = self.make_matmul(mlp.fc2, fc2_matmul_basename, root_input=f"{act_fn_name}/output_0")
fc2_add_name = f"/model/layers.{layer_id}/mlp/fc2/Add"
self.make_add_bias(mlp.fc2.bias.detach().numpy(), fc2_add_name, root_input=f"{fc2_matmul_name}/output_0")
fc2_name = fc2_matmul_name
if fc2_bias_exists:
fc2_add_name = f"/model/layers.{layer_id}/mlp/fc2/Add"
self.make_add_bias(mlp.fc2.bias.detach().numpy(), fc2_add_name, root_input=f"{fc2_name}/output_0")
fc2_name = fc2_add_name

# Assign output 0 of MLP layer as output of last layer
self.mlp_attrs["output_0"] = f"{fc2_add_name}/output_0"
self.mlp_attrs["output_0"] = f"{fc2_name}/output_0"

def make_block_sparse_moe(self, layer_id, bsm, root_input):
# Make nodes for the QMoE subgraph
@@ -1968,35 +2007,40 @@
return output_name

def make_lm_head(self, lm_head):
# Check if there are ops to insert after MatMul
bias_exists = lm_head.bias is not None
scale_exists = self.lm_head_attrs["scale"] != 1
mask_exists = self.lm_head_attrs["mask"] is not None
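# The lm_head chain is MatMul -> Add (bias) -> Mul (scale) -> Where (mask), where each optional op
# reads the previous op's output; whichever op comes last writes the graph's "logits" output.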

matmul_basename = "/lm_head/MatMul"
root_input = self.layernorm_attrs["output_0"]
matmul_name = self.make_matmul(lm_head, matmul_basename, root_input, logits=not bias_exists and not scale_exists)
matmul_name = self.make_matmul(lm_head, matmul_basename, root_input, logits=not(bias_exists or scale_exists or mask_exists))
lm_name = matmul_name

if bias_exists:
add_name = "/lm_head/Add"
self.make_add_bias(lm_head.bias.detach().numpy(), add_name, root_input=f"{matmul_name}/output_0", logits=not scale_exists)
self.make_add_bias(lm_head.bias.detach().numpy(), add_name, root_input=f"{lm_name}/output_0", logits=not(scale_exists or mask_exists))
lm_name = add_name

if scale_exists:
mul_name = "/lm_head/Mul"
mul_inputs = [f"{matmul_name if not bias_exists else add_name}/output_0", f"/model/constants/{self.to_str_dtype[self.io_dtype]}/0D/{self.lm_head_attrs['scale']}"]
mul_inputs = [f"{lm_name}/output_0", f"/model/constants/{self.to_str_dtype[self.io_dtype]}/0D/{self.lm_head_attrs['scale']}"]
mul_output = "logits" if not mask_exists else f"{mul_name}/output_0"
self.make_node('Mul', inputs=mul_inputs, outputs=[mul_output], name=mul_name)
self.make_value_info(mul_output, self.io_dtype, shape=['batch_size', 'sequence_length', self.vocab_size])
lm_name = mul_name

if mask_exists:
# Save logits mask as initializer
logits_mask_name = "logits_mask"
self.make_external_tensor(self.lm_head_attrs["mask"].detach().numpy(), logits_mask_name)

where_name = "/lm_head/Where"
where_inputs = [logits_mask_name, f"/model/constants/{self.to_str_dtype[self.io_dtype]}/0D/{np.finfo(self.to_numpy_dtype[self.io_dtype]).min}", f"{mul_name}/output_0"]
where_inputs = [logits_mask_name, f"/model/constants/{self.to_str_dtype[self.io_dtype]}/0D/{np.finfo(self.to_numpy_dtype[self.io_dtype]).min}", f"{lm_name}/output_0"]
where_output = "logits"
self.make_node('Where', inputs=where_inputs, outputs=[where_output], name=where_name)
self.make_value_info(where_output, self.io_dtype, shape=['batch_size', 'sequence_length', self.vocab_size])
lm_name = where_name

def make_layer(self, layer_id, layer):
# Each LLM decoder layer is typically defined as:
@@ -2552,7 +2596,6 @@
class PhiModel(Model):
def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
# self.input_shapes["position_ids"] = [1] # Note: This is optional and only needed if you want position_ids to be an int instead of a 2D tensor
self.layernorm_attrs["simple"] = False
self.rotemb_attrs["num_heads"] = self.num_attn_heads
self.rotemb_attrs["rotary_embedding_dim"] = int(self.head_size * self.rotemb_attrs["partial_rotary_factor"])
@@ -3043,6 +3086,13 @@
super().make_layer(layer_id, layer)


class GraniteModel(MistralModel):
def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
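# Granite configs expose embedding_multiplier and logits_scaling: token embeddings are scaled by
# embedding_multiplier and the lm_head logits by 1 / logits_scaling.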
self.embed_attrs["scale"] = config.embedding_multiplier
self.lm_head_attrs["scale"] = 1 / config.logits_scaling


def check_extra_options(kv_pairs):
"""
Check key-value pairs and set values correctly
@@ -3131,6 +3181,8 @@
onnx_model = GemmaModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
elif config.architectures[0] == "Gemma2ForCausalLM":
onnx_model = Gemma2Model(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
elif config.architectures[0] == "GraniteForCausalLM":
onnx_model = GraniteModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
elif config.architectures[0] == "LlamaForCausalLM":
onnx_model = LlamaModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
elif config.architectures[0] == "MistralForCausalLM":
test/python/_test_utils.py (3 changes: 2 additions & 1 deletion)
@@ -55,7 +55,8 @@ def run_subprocess(
def get_model_paths():
hf_paths = {
"phi-2": "microsoft/phi-2",
# "phi-3-mini": "microsoft/Phi-3-mini-128k-instruct",
"phi-3.5": "microsoft/Phi-3.5-mini-instruct",
"granite-3.0": "ibm-granite/granite-3.0-2b-instruct",
}

ci_data_path = os.path.join("/", "data", "ortgenai_pytorch_models")