Add Gemma to model builder tool (#113)
### Description

This PR adds [Google's
Gemma](https://blog.google/technology/developers/gemma-open-models/)
models to the model builder tool.

### Motivation and Context

Google's Gemma is a family of lightweight, open foundation models built from the
same research and technology as [Google's
Gemini](https://blog.google/technology/ai/google-gemini-ai/) models.
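
For reference, a minimal sketch of driving the updated builder from Python for a Gemma checkpoint. The model ID `google/gemma-2b`, the fp16 precision, and the CUDA execution provider are illustrative assumptions, not requirements of this PR.

```
# Sketch only: programmatic equivalent of
#   python3 -m onnxruntime_genai.models.builder -m google/gemma-2b -o ./gemma-2b-onnx -p fp16 -e cuda -c ./cache_dir
from onnxruntime_genai.models.builder import create_model

create_model(
    model_name_or_path="google/gemma-2b",  # assumed Hugging Face model ID
    output_dir="./gemma-2b-onnx",
    precision="fp16",
    execution_provider="cuda",
    cache_dir="./cache_dir",
)
```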
kunal-vaishnavi authored Feb 23, 2024
1 parent 0a6a2f9 commit fe24e9d
Showing 2 changed files with 86 additions and 44 deletions.
29 changes: 15 additions & 14 deletions src/python/py/models/README.md
@@ -5,6 +5,7 @@ This folder contains the model builder tool, which greatly accelerates creating
## Current Support
The tool currently supports the following model architectures.

- Gemma
- LLaMA
- Mistral
- Phi
@@ -15,50 +16,50 @@ The tool currently supports the following model architectures.
For all available options, please use the `-h/--help` flag.
```
# From wheel:
python -m onnxruntime_genai.models.builder --help
python3 -m onnxruntime_genai.models.builder --help
# From source:
python builder.py --help
python3 builder.py --help
```

### Original Model From Hugging Face
This scenario is where your PyTorch model is not downloaded locally (either in the default Hugging Face cache directory or in a local folder on disk).
```
# From wheel:
python -m onnxruntime_genai.models.builder -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_to_save_hf_files
python3 -m onnxruntime_genai.models.builder -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_to_save_hf_files
# From source:
python builder.py -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_to_save_hf_files
python3 builder.py -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_to_save_hf_files
```

### Original Model From Disk
This scenario is where your PyTorch model is already downloaded locally (either in the default Hugging Face cache directory or in a local folder on disk).
```
# From wheel:
python -m onnxruntime_genai.models.builder -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved
python3 -m onnxruntime_genai.models.builder -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved
# From source:
python builder.py -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved
python3 builder.py -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved
```

### Customized or Finetuned Model
This scenario is where your PyTorch model has been customized or finetuned for one of the currently supported model architectures and your model can be loaded in Hugging Face.
```
# From wheel:
python -m onnxruntime_genai.models.builder -m path_to_local_folder_on_disk -o /path/to/output/folder -p precision -e execution_provider
python3 -m onnxruntime_genai.models.builder -m path_to_local_folder_on_disk -o /path/to/output/folder -p precision -e execution_provider
# From source:
python builder.py -m path_to_local_folder_on_disk -o /path/to/output/folder -p precision -e execution_provider
python3 builder.py -m path_to_local_folder_on_disk -o /path/to/output/folder -p precision -e execution_provider
```

### Extra Options
This scenario is for when you want to have control over some specific settings. The below example shows how you can pass key-value arguments to `--extra_options`.
```
# From wheel:
python -m onnxruntime_genai.models.builder -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_to_save_hf_files --extra_options filename=decoder.onnx
python3 -m onnxruntime_genai.models.builder -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_to_save_hf_files --extra_options filename=decoder.onnx
# From source:
python builder.py -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_to_save_hf_files --extra_options filename=decoder.onnx
python3 builder.py -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_to_save_hf_files --extra_options filename=decoder.onnx
```
To see all available options through `--extra_options`, please use the `help` commands in the `Full Usage` section above.

@@ -82,10 +83,10 @@ tokenizer.save_pretrained(cache_dir)
This option is the simplest but it will download another copy of the PyTorch model onto disk to accommodate the change in the number of hidden layers.
```
# From wheel:
python -m onnxruntime_genai.models.builder -m model_name -o /path/to/output/folder -p precision -e execution_provider --extra_options num_hidden_layers=4
python3 -m onnxruntime_genai.models.builder -m model_name -o /path/to/output/folder -p precision -e execution_provider --extra_options num_hidden_layers=4
# From source:
python builder.py -m model_name -o /path/to/output/folder -p precision -e execution_provider --extra_options num_hidden_layers=4
python3 builder.py -m model_name -o /path/to/output/folder -p precision -e execution_provider --extra_options num_hidden_layers=4
```

#### Option 2: Edit the config.json file on disk and then run the model builder tool
@@ -96,8 +97,8 @@ python builder.py -m model_name -o /path/to/output/folder -p precision -e execut

```
# From wheel:
python -m onnxruntime_genai.models.builder -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved
python3 -m onnxruntime_genai.models.builder -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved
# From source:
python builder.py -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved
python3 builder.py -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved
```
101 changes: 71 additions & 30 deletions src/python/py/models/builder.py
@@ -82,6 +82,11 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
"total_seq_len": "", # Size of total sequence length in attention mask (used as input to GroupQueryAttention)
}

# Embedding-specific variables
self.embed_attrs = {
"normalize": False, # Normalize output of Embedding layer
}

# LayerNorm-specific variables
self.layernorm_attrs = {
"simple": True, # Use SimplifiedLayerNorm/SkipSimplifiedLayerNorm vs. LayerNorm/SkipLayerNorm
@@ -91,13 +96,15 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
"skip_input": "", # Skip input from parent node for SkipLayerNorm
"output_0": "", # Output 0 for LayerNorm and SkipLayerNorm
"output_3": "", # Output 3 for SkipLayerNorm
"add_offset": 0, # Offset value for LayerNorm weight
}

# RotaryEmbedding-specific variables
self.rotemb_attrs = {
"create_rotary_embedding_caches": True, # Create cos/sin caches for rotary embeddings
"num_heads": 0, # For partial rotary embeddings (RotaryEmbedding kernel expects a default value of 0)
"rotary_embedding_dim": 0, # For partial rotary embeddings (RotaryEmbedding kernel expects a default value of 0)
"theta": config.rope_theta, # Base value if calculating cos/sin caches from scratch
}

# Attention-specific variables (MHA, GQA, GQA + Rot.Emb., etc.)
@@ -151,14 +158,13 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
}

print(f"Saving GenAI config in {out_dir}")
with open(os.path.join(os.path.dirname(out_dir),"genai_config.json"), "w") as f:
with open(os.path.join(out_dir,"genai_config.json"), "w") as f:
json.dump(genai_config, f, indent=4)

def save_processing(self, model_name_or_path, extra_kwargs, out_dir):
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **extra_kwargs)
dest_dir = os.path.dirname(out_dir)
print(f"Saving processing files in {dest_dir} for GenAI")
tokenizer.save_pretrained(dest_dir)
print(f"Saving processing files in {out_dir} for GenAI")
tokenizer.save_pretrained(out_dir)

def save_model(self, out_dir):
print(f"Saving ONNX model in {out_dir}")
@@ -198,8 +204,8 @@ def save_model(self, out_dir):
model = self.to_int4(model)

# Save ONNX model with only one external data file and delete any existing duplicate copies
out_path = os.path.join(os.path.dirname(out_dir), self.filename)
data_path = os.path.join(os.path.dirname(out_dir), os.path.basename(out_path) + ".data")
out_path = os.path.join(out_dir, self.filename)
data_path = os.path.join(out_dir, os.path.basename(out_path) + ".data")
if os.path.exists(out_path):
print(f"Overwriting {out_path}")
os.remove(out_path)
@@ -459,20 +465,34 @@ def make_embedding(self, embedding):
weight = "model.embed_tokens.weight"
self.make_external_tensor(embedding.astype(self.to_numpy_dtype[self.io_dtype]), weight)

name = "/model/embed_tokens/Gather"
output = f"{name}/output_0"
self.make_node('Gather', inputs=[weight, 'input_ids'], outputs=[output], name=name)
self.make_value_info(output, self.io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size])
basename = "/model/embed_tokens"
gather_name = f"{basename}/Gather"
gather_output = f"{gather_name}/output_0"
self.make_node('Gather', inputs=[weight, 'input_ids'], outputs=[gather_output], name=gather_name)
self.make_value_info(gather_output, self.io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size])

if self.embed_attrs["normalize"]:
# Normalize the embeddings
norm_name = f"{basename}/Mul"
norm_inputs = [gather_output, f"/model/constants/{self.to_str_dtype[self.io_dtype]}/0D/{np.round(np.sqrt(self.hidden_size), decimals=2)}"]
norm_output = f"{norm_name}/output_0"
self.make_node('Mul', inputs=norm_inputs, outputs=[norm_output], name=norm_name)
self.make_value_info(norm_output, self.io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size])

layernorm_attrs_value = norm_output
else:
layernorm_attrs_value = gather_output

self.layernorm_attrs["root_input"] = layernorm_attrs_value
self.layernorm_attrs["skip_input"] = layernorm_attrs_value

self.layernorm_attrs["root_input"] = output
self.layernorm_attrs["skip_input"] = output
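
For context, a standalone NumPy sketch (not part of this diff) of what the new Gather → Mul subgraph computes when `embed_attrs["normalize"]` is enabled; the hidden size and token IDs below are toy values.

```
import numpy as np

hidden_size = 2048                                  # toy value for illustration
vocab = np.random.rand(256, hidden_size).astype(np.float32)
ids = np.array([[1, 5, 42]])                        # (batch_size, sequence_length)

gathered = vocab[ids]                               # Gather output: (batch, seq, hidden)
scale = np.round(np.sqrt(hidden_size), decimals=2)  # the 0-D constant fed to the Mul node
normalized = gathered * scale                       # becomes root_input/skip_input for the first LayerNorm
```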

def make_layernorm(self, layer_id, layernorm, skip, simple, location):
root_input = self.layernorm_attrs["root_input"]
skip_input = self.layernorm_attrs["skip_input"]

weight = f"model.layers.{layer_id}.{location}_layernorm.weight"
self.make_external_tensor(layernorm.weight.detach().numpy().astype(self.to_numpy_dtype[self.io_dtype]), weight)
self.make_external_tensor(layernorm.weight.detach().numpy().astype(self.to_numpy_dtype[self.io_dtype]) + self.layernorm_attrs["add_offset"], weight)
bias = f"model.layers.{layer_id}.{location}_layernorm.bias"
if not simple:
self.make_external_tensor(layernorm.bias.detach().numpy().astype(self.to_numpy_dtype[self.io_dtype]), bias)
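
For context (standalone sketch, not part of this diff): Gemma's RMSNorm scales hidden states by `1 + weight`, so `GemmaModel` sets `add_offset` to 1 and the offset is folded into the exported weight tensor, letting the existing SimplifiedLayerNorm path be reused unchanged.

```
import numpy as np

raw_weight = np.zeros(2048, dtype=np.float32)  # Gemma RMSNorm weights are learned around zero
add_offset = 1                                 # layernorm_attrs["add_offset"] set by GemmaModel
exported = raw_weight + add_offset             # tensor handed to make_external_tensor above
```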
@@ -510,11 +530,21 @@ def make_rotary_embedding(self, rotemb, name, root_input, **kwargs):
cos_cache_name, sin_cache_name = "cos_cache", "sin_cache"

if self.rotemb_attrs["create_rotary_embedding_caches"]:
if not hasattr(rotemb, "cos_cached"):
# Create cos/sin caches if not already created
inv_freq = 1.0 / (self.rotemb_attrs["theta"] ** (torch.arange(0, self.head_size, 2, dtype=torch.int64).float() / self.head_size))
t = torch.arange(self.context_length, dtype=torch.int64).type_as(inv_freq)
freqs = torch.outer(t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos_cache, sin_cache = emb.cos(), emb.sin()
else:
cos_cache, sin_cache = rotemb.cos_cached, rotemb.sin_cached

# Reshape cos/sin cache from (M, H) to (M, H/2)
hidden_dim = rotemb.cos_cached.shape[-1]
cos_cache = rotemb.cos_cached.squeeze()[:, : (hidden_dim // 2)].detach().numpy()
hidden_dim = cos_cache.shape[-1]
cos_cache = cos_cache.squeeze()[:, : (hidden_dim // 2)].detach().numpy()
self.make_external_tensor(cos_cache.astype(self.to_numpy_dtype[self.io_dtype]), cos_cache_name)
sin_cache = rotemb.sin_cached.squeeze()[:, : (hidden_dim // 2)].detach().numpy()
sin_cache = sin_cache.squeeze()[:, : (hidden_dim // 2)].detach().numpy()
self.make_external_tensor(sin_cache.astype(self.to_numpy_dtype[self.io_dtype]), sin_cache_name)

self.rotemb_attrs["create_rotary_embedding_caches"] = False
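
A condensed standalone restatement of the fallback cache computation above, showing the resulting cache shapes; the head size, context length, and theta below are illustrative values only.

```
import torch

head_size, context_length, theta = 256, 8192, 10000.0  # illustrative values
inv_freq = 1.0 / (theta ** (torch.arange(0, head_size, 2).float() / head_size))
freqs = torch.outer(torch.arange(context_length).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)                 # (context_length, head_size)
cos_cache = emb.cos()[:, : head_size // 2]              # reshaped to (M, H/2)
sin_cache = emb.sin()[:, : head_size // 2]
```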
@@ -681,7 +711,7 @@ def make_mlp(self, layer_id, mlp, root_input):
# Assign output 0 of previous MatMul as skip input to next SkipLayerNorm
self.layernorm_attrs["skip_input"] = f"{down_name}/output_0"

def make_activation_with_mul(self, layer_id, root_input, activation):
def make_activation_with_mul(self, layer_id, root_input, activation, domain):
# Make nodes for this activation subgraph
#
# root_input (GateProjMatMul)
@@ -691,7 +721,7 @@ def make_activation_with_mul(self, layer_id, root_input, activation):
# Mul
act_name = f"/model/layers.{layer_id}/mlp/act_fn/{activation}"
act_output = f"{act_name}/output_0"
self.make_node(activation, inputs=[root_input], outputs=[act_output], name=act_name)
self.make_node(activation, inputs=[root_input], outputs=[act_output], name=act_name, domain=domain)
self.make_value_info(act_output, dtype=self.io_dtype, shape=["batch_size", "sequence_length", self.intermediate_size])

mul_act_name = f"/model/layers.{layer_id}/mlp/act_fn/Mul"
@@ -700,26 +730,26 @@ def make_activation_with_mul(self, layer_id, root_input, activation):

return mul_act_name

def make_fast_gelu(self, layer_id, root_input):
def make_gelu(self, layer_id, root_input, activation):
# Make nodes for this activation subgraph
#
# root_input (Add)
# |
# FastGelu
fast_gelu_name = f"/model/layers.{layer_id}/mlp/act_fn/FastGelu"
output = f"{fast_gelu_name}/output_0"
self.make_node("FastGelu", inputs=[root_input], outputs=[output], name=fast_gelu_name, domain="com.microsoft")
# GeluAct
gelu_name = f"/model/layers.{layer_id}/mlp/act_fn/{activation}"
output = f"{gelu_name}/output_0"
self.make_node(activation, inputs=[root_input], outputs=[output], name=gelu_name, domain="com.microsoft")
self.make_value_info(output, self.io_dtype, shape=['batch_size', 'sequence_length', self.intermediate_size])

return fast_gelu_name
return gelu_name

def make_activation(self, layer_id, root_input):
if self.activation in {"silu", "swish"}:
output_name = self.make_activation_with_mul(layer_id, root_input, "Sigmoid")
output_name = self.make_activation_with_mul(layer_id, root_input, activation="Sigmoid", domain=None)
elif self.activation in {"gelu_new", "gelu_fast"}:
output_name = self.make_fast_gelu(layer_id, root_input)
output_name = self.make_gelu(layer_id, root_input, activation="FastGelu")
elif self.activation in {"gelu"}:
output_name = self.make_activation_with_mul(layer_id, root_input, "Gelu")
output_name = self.make_gelu(layer_id, root_input, activation="Gelu")
else:
raise NotImplementedError(f"The {self.activation} activation function is not currently supported.")
return output_name
@@ -1298,6 +1328,13 @@ def make_layer(self, layer_id, layer):
self.layernorm_attrs["skip_input"] = f"{residual_add_name}/output_0"


class GemmaModel(MistralModel):
def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
self.embed_attrs["normalize"] = True
self.layernorm_attrs["add_offset"] = 1


def parse_extra_options(kv_items):
"""
Parse key value pairs that are separated by '='
@@ -1314,14 +1351,18 @@ def parse_extra_options(kv_items):


def create_model(model_name_or_path, output_dir, precision, execution_provider, cache_dir, **extra_options):
os.makedirs(output_dir, exist_ok=True)
os.makedirs(cache_dir, exist_ok=True)
extra_kwargs = {} if os.path.exists(model_name_or_path) else {"cache_dir": cache_dir, "use_auth_token": True}
config = AutoConfig.from_pretrained(model_name_or_path, **extra_kwargs)

# Set input/output precision of ONNX model
io_dtype = TensorProto.FLOAT if precision in {"int8", "fp32"} or (precision == "int4" and execution_provider == "cpu") else TensorProto.FLOAT16

if config.architectures[0] == "LlamaForCausalLM":
# List architecture options in alphabetical order
if config.architectures[0] == "GemmaForCausalLM":
onnx_model = GemmaModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
elif config.architectures[0] == "LlamaForCausalLM":
onnx_model = LlamaModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
elif config.architectures[0] == "MistralForCausalLM":
onnx_model = MistralModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
@@ -1357,7 +1398,7 @@ def get_args():
"-o",
"--output",
required=True,
help="Path to folder containing ONNX model and additional files (e.g. GenAI config, external data files, etc.)",
help="Path to folder to store ONNX model and additional files (e.g. GenAI config, external data files, etc.)",
)

parser.add_argument(
@@ -1381,7 +1422,7 @@ def get_args():
"--cache_dir",
required=False,
type=str,
default=os.path.join(os.path.dirname(os.path.realpath(__file__)), 'cache_dir'),
default=os.path.join('.', 'cache_dir'),
help="Model cache directory (if providing model name and not folder path)",
)

