diff --git a/src/python/py/models/README.md b/src/python/py/models/README.md
index f9945d443..4dc30e7e2 100644
--- a/src/python/py/models/README.md
+++ b/src/python/py/models/README.md
@@ -5,6 +5,7 @@ This folder contains the model builder tool, which greatly accelerates creating
 ## Current Support
 The tool currently supports the following model architectures.
 
+- Gemma
 - LLaMA
 - Mistral
 - Phi
@@ -15,50 +16,50 @@ The tool currently supports the following model architectures.
 For all available options, please use the `-h/--help` flag.
 ```
 # From wheel:
-python -m onnxruntime_genai.models.builder --help
+python3 -m onnxruntime_genai.models.builder --help
 
 # From source:
-python builder.py --help
+python3 builder.py --help
 ```
 
 ### Original Model From Hugging Face
 This scenario is where your PyTorch model is not downloaded locally (either in the default Hugging Face cache directory or in a local folder on disk).
 ```
 # From wheel:
-python -m onnxruntime_genai.models.builder -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_to_save_hf_files
+python3 -m onnxruntime_genai.models.builder -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_to_save_hf_files
 
 # From source:
-python builder.py -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_to_save_hf_files
+python3 builder.py -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_to_save_hf_files
 ```
 
 ### Original Model From Disk
 This scenario is where your PyTorch model is already downloaded locally (either in the default Hugging Face cache directory or in a local folder on disk).
 ```
 # From wheel:
-python -m onnxruntime_genai.models.builder -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved
+python3 -m onnxruntime_genai.models.builder -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved
 
 # From source:
-python builder.py -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved
+python3 builder.py -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved
 ```
 
 ### Customized or Finetuned Model
 This scenario is where your PyTorch model has been customized or finetuned for one of the currently supported model architectures and your model can be loaded in Hugging Face.
 ```
 # From wheel:
-python -m onnxruntime_genai.models.builder -m path_to_local_folder_on_disk -o /path/to/output/folder -p precision -e execution_provider
+python3 -m onnxruntime_genai.models.builder -m path_to_local_folder_on_disk -o /path/to/output/folder -p precision -e execution_provider
 
 # From source:
-python builder.py -m path_to_local_folder_on_disk -o /path/to/output/folder -p precision -e execution_provider
+python3 builder.py -m path_to_local_folder_on_disk -o /path/to/output/folder -p precision -e execution_provider
 ```
 
 ### Extra Options
 This scenario is for when you want to have control over some specific settings. The below example shows how you can pass key-value arguments to `--extra_options`.
 ```
 # From wheel:
-python -m onnxruntime_genai.models.builder -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_to_save_hf_files --extra_options filename=decoder.onnx
+python3 -m onnxruntime_genai.models.builder -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_to_save_hf_files --extra_options filename=decoder.onnx
 
 # From source:
-python builder.py -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_to_save_hf_files --extra_options filename=decoder.onnx
+python3 builder.py -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_to_save_hf_files --extra_options filename=decoder.onnx
 ```
 
 To see all available options through `--extra_options`, please use the `help` commands in the `Full Usage` section above.
@@ -82,10 +83,10 @@ tokenizer.save_pretrained(cache_dir)
 This option is the simplest but it will download another copy of the PyTorch model onto disk to accommodate the change in the number of hidden layers.
 ```
 # From wheel:
-python -m onnxruntime_genai.models.builder -m model_name -o /path/to/output/folder -p precision -e execution_provider --extra_options num_hidden_layers=4
+python3 -m onnxruntime_genai.models.builder -m model_name -o /path/to/output/folder -p precision -e execution_provider --extra_options num_hidden_layers=4
 
 # From source:
-python builder.py -m model_name -o /path/to/output/folder -p precision -e execution_provider --extra_options num_hidden_layers=4
+python3 builder.py -m model_name -o /path/to/output/folder -p precision -e execution_provider --extra_options num_hidden_layers=4
 ```
 
 #### Option 2: Edit the config.json file on disk and then run the model builder tool
@@ -96,8 +97,8 @@ python builder.py -m model_name -o /path/to/output/folder -p precision -e execut
 
 ```
 # From wheel:
-python -m onnxruntime_genai.models.builder -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved
+python3 -m onnxruntime_genai.models.builder -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved
 
 # From source:
-python builder.py -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved
+python3 builder.py -m model_name -o /path/to/output/folder -p precision -e execution_provider -c cache_dir_where_hf_files_are_saved
 ```
\ No newline at end of file
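Note on the README changes: besides the `python` → `python3` switch, Gemma joins the supported architectures, and all of the commands above work unchanged for Gemma checkpoints. As a rough sketch (not part of this patch), the same conversion can also be driven programmatically through `create_model` in builder.py; the model id `google/gemma-2b` and both paths below are illustrative placeholders.

```python
# Hypothetical programmatic equivalent of the README commands for a Gemma
# checkpoint; "google/gemma-2b" and the paths are placeholders, not values
# taken from this patch.
from onnxruntime_genai.models.builder import create_model

create_model(
    "google/gemma-2b",   # model_name_or_path: any GemmaForCausalLM checkpoint
    "./gemma-2b-onnx",   # output_dir: now created automatically if missing
    "int4",              # precision
    "cpu",               # execution_provider
    "./cache_dir",       # cache_dir for downloaded Hugging Face files
)
```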
diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py
index 12f9aa272..e327b0f35 100644
--- a/src/python/py/models/builder.py
+++ b/src/python/py/models/builder.py
@@ -82,6 +82,11 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
             "total_seq_len": "",                     # Size of total sequence length in attention mask (used as input to GroupQueryAttention)
         }
 
+        # Embedding-specific variables
+        self.embed_attrs = {
+            "normalize": False,                      # Normalize output of Embedding layer
+        }
+
         # LayerNorm-specific variables
         self.layernorm_attrs = {
             "simple": True,                          # Use SimplifiedLayerNorm/SkipSimplifiedLayerNorm vs. LayerNorm/SkipLayerNorm
@@ -91,6 +96,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
             "skip_input": "",                        # Skip input from parent node for SkipLayerNorm
             "output_0": "",                          # Output 0 for LayerNorm and SkipLayerNorm
             "output_3": "",                          # Output 3 for SkipLayerNorm
+            "add_offset": 0,                         # Offset value for LayerNorm weight
         }
 
         # RotaryEmbedding-specific variables
@@ -98,6 +104,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
             "create_rotary_embedding_caches": True,  # Create cos/sin caches for rotary embeddings
             "num_heads": 0,                          # For partial rotary embeddings (RotaryEmbedding kernel expects a default value of 0)
             "rotary_embedding_dim": 0,               # For partial rotary embeddings (RotaryEmbedding kernel expects a default value of 0)
+            "theta": config.rope_theta,              # Base value if calculating cos/sin caches from scratch
         }
 
         # Attention-specific variables (MHA, GQA, GQA + Rot.Emb., etc.)
@@ -151,14 +158,13 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
         }
 
         print(f"Saving GenAI config in {out_dir}")
-        with open(os.path.join(os.path.dirname(out_dir),"genai_config.json"), "w") as f:
+        with open(os.path.join(out_dir,"genai_config.json"), "w") as f:
             json.dump(genai_config, f, indent=4)
 
     def save_processing(self, model_name_or_path, extra_kwargs, out_dir):
         tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **extra_kwargs)
-        dest_dir = os.path.dirname(out_dir)
-        print(f"Saving processing files in {dest_dir} for GenAI")
-        tokenizer.save_pretrained(dest_dir)
+        print(f"Saving processing files in {out_dir} for GenAI")
+        tokenizer.save_pretrained(out_dir)
 
     def save_model(self, out_dir):
         print(f"Saving ONNX model in {out_dir}")
@@ -198,8 +204,8 @@ def save_model(self, out_dir):
             model = self.to_int4(model)
 
         # Save ONNX model with only one external data file and delete any existing duplicate copies
-        out_path = os.path.join(os.path.dirname(out_dir), self.filename)
-        data_path = os.path.join(os.path.dirname(out_dir), os.path.basename(out_path) + ".data")
+        out_path = os.path.join(out_dir, self.filename)
+        data_path = os.path.join(out_dir, os.path.basename(out_path) + ".data")
         if os.path.exists(out_path):
             print(f"Overwriting {out_path}")
             os.remove(out_path)
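The three `os.path.dirname(out_dir)` removals above share one root cause: `dirname` strips the last path component, so the GenAI config, tokenizer files, and ONNX weights were all written one level above the folder requested with `-o/--output`. A minimal illustration (the path is a placeholder):

```python
import os

out_dir = "/path/to/output/folder"

# Before this patch: files landed in the parent directory
print(os.path.join(os.path.dirname(out_dir), "genai_config.json"))
# /path/to/output/genai_config.json

# After this patch: files land in the requested output folder
print(os.path.join(out_dir, "genai_config.json"))
# /path/to/output/folder/genai_config.json
```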
@@ -459,20 +465,34 @@ def make_embedding(self, embedding):
         weight = "model.embed_tokens.weight"
         self.make_external_tensor(embedding.astype(self.to_numpy_dtype[self.io_dtype]), weight)
 
-        name = "/model/embed_tokens/Gather"
-        output = f"{name}/output_0"
-        self.make_node('Gather', inputs=[weight, 'input_ids'], outputs=[output], name=name)
-        self.make_value_info(output, self.io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size])
+        basename = "/model/embed_tokens"
+        gather_name = f"{basename}/Gather"
+        gather_output = f"{gather_name}/output_0"
+        self.make_node('Gather', inputs=[weight, 'input_ids'], outputs=[gather_output], name=gather_name)
+        self.make_value_info(gather_output, self.io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size])
+
+        if self.embed_attrs["normalize"]:
+            # Normalize the embeddings
+            norm_name = f"{basename}/Mul"
+            norm_inputs = [gather_output, f"/model/constants/{self.to_str_dtype[self.io_dtype]}/0D/{np.round(np.sqrt(self.hidden_size), decimals=2)}"]
+            norm_output = f"{norm_name}/output_0"
+            self.make_node('Mul', inputs=norm_inputs, outputs=[norm_output], name=norm_name)
+            self.make_value_info(norm_output, self.io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size])
+
+            layernorm_attrs_value = norm_output
+        else:
+            layernorm_attrs_value = gather_output
+
+        self.layernorm_attrs["root_input"] = layernorm_attrs_value
+        self.layernorm_attrs["skip_input"] = layernorm_attrs_value
 
-        self.layernorm_attrs["root_input"] = output
-        self.layernorm_attrs["skip_input"] = output
 
     def make_layernorm(self, layer_id, layernorm, skip, simple, location):
         root_input = self.layernorm_attrs["root_input"]
         skip_input = self.layernorm_attrs["skip_input"]
 
         weight = f"model.layers.{layer_id}.{location}_layernorm.weight"
-        self.make_external_tensor(layernorm.weight.detach().numpy().astype(self.to_numpy_dtype[self.io_dtype]), weight)
+        self.make_external_tensor(layernorm.weight.detach().numpy().astype(self.to_numpy_dtype[self.io_dtype]) + self.layernorm_attrs["add_offset"], weight)
         bias = f"model.layers.{layer_id}.{location}_layernorm.bias"
         if not simple:
             self.make_external_tensor(layernorm.bias.detach().numpy().astype(self.to_numpy_dtype[self.io_dtype]), bias)
@@ -510,11 +530,21 @@ def make_rotary_embedding(self, rotemb, name, root_input, **kwargs):
         cos_cache_name, sin_cache_name = "cos_cache", "sin_cache"
 
         if self.rotemb_attrs["create_rotary_embedding_caches"]:
+            if not hasattr(rotemb, "cos_cached"):
+                # Create cos/sin caches if not already created
+                inv_freq = 1.0 / (self.rotemb_attrs["theta"] ** (torch.arange(0, self.head_size, 2, dtype=torch.int64).float() / self.head_size))
+                t = torch.arange(self.context_length, dtype=torch.int64).type_as(inv_freq)
+                freqs = torch.outer(t, inv_freq)
+                emb = torch.cat((freqs, freqs), dim=-1)
+                cos_cache, sin_cache = emb.cos(), emb.sin()
+            else:
+                cos_cache, sin_cache = rotemb.cos_cached, rotemb.sin_cached
+
             # Reshape cos/sin cache from (M, H) to (M, H/2)
-            hidden_dim = rotemb.cos_cached.shape[-1]
-            cos_cache = rotemb.cos_cached.squeeze()[:, : (hidden_dim // 2)].detach().numpy()
+            hidden_dim = cos_cache.shape[-1]
+            cos_cache = cos_cache.squeeze()[:, : (hidden_dim // 2)].detach().numpy()
             self.make_external_tensor(cos_cache.astype(self.to_numpy_dtype[self.io_dtype]), cos_cache_name)
-            sin_cache = rotemb.sin_cached.squeeze()[:, : (hidden_dim // 2)].detach().numpy()
+            sin_cache = sin_cache.squeeze()[:, : (hidden_dim // 2)].detach().numpy()
             self.make_external_tensor(sin_cache.astype(self.to_numpy_dtype[self.io_dtype]), sin_cache_name)
 
             self.rotemb_attrs["create_rotary_embedding_caches"] = False
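The new `hasattr` branch above covers models whose Hugging Face rotary-embedding module does not precompute `cos_cached`/`sin_cached`; the caches are rebuilt from `rope_theta` instead. Below is a standalone sketch of that arithmetic; `head_size=256`, `context_length=8192`, and `theta=10000.0` are illustrative values, not taken from this patch.

```python
import torch

head_size, context_length, theta = 256, 8192, 10000.0  # illustrative values

# Same computation as the new branch in make_rotary_embedding
inv_freq = 1.0 / (theta ** (torch.arange(0, head_size, 2, dtype=torch.int64).float() / head_size))
t = torch.arange(context_length, dtype=torch.int64).type_as(inv_freq)
freqs = torch.outer(t, inv_freq)           # (context_length, head_size // 2)
emb = torch.cat((freqs, freqs), dim=-1)    # (context_length, head_size)
cos_cache, sin_cache = emb.cos(), emb.sin()

# The builder then keeps only the first half of the last axis,
# reshaping (M, H) to (M, H/2) before exporting the tensors.
print(cos_cache[:, : head_size // 2].shape)  # torch.Size([8192, 128])
```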
@@ -681,7 +711,7 @@ def make_mlp(self, layer_id, mlp, root_input):
         # Assign output 0 of previous MatMul as skip input to next SkipLayerNorm
         self.layernorm_attrs["skip_input"] = f"{down_name}/output_0"
 
-    def make_activation_with_mul(self, layer_id, root_input, activation):
+    def make_activation_with_mul(self, layer_id, root_input, activation, domain):
         # Make nodes for this activation subgraph
         #
         #       root_input (GateProjMatMul)
         #           |
@@ -691,7 +721,7 @@ def make_activation_with_mul(self, layer_id, root_input, activation):
         #          Mul
         act_name = f"/model/layers.{layer_id}/mlp/act_fn/{activation}"
         act_output = f"{act_name}/output_0"
-        self.make_node(activation, inputs=[root_input], outputs=[act_output], name=act_name)
+        self.make_node(activation, inputs=[root_input], outputs=[act_output], name=act_name, domain=domain)
         self.make_value_info(act_output, dtype=self.io_dtype, shape=["batch_size", "sequence_length", self.intermediate_size])
 
         mul_act_name = f"/model/layers.{layer_id}/mlp/act_fn/Mul"
@@ -700,26 +730,26 @@
         return mul_act_name
 
-    def make_fast_gelu(self, layer_id, root_input):
+    def make_gelu(self, layer_id, root_input, activation):
         # Make nodes for this activation subgraph
         #
         #       root_input (Add)
         #           |
-        #       FastGelu
-        fast_gelu_name = f"/model/layers.{layer_id}/mlp/act_fn/FastGelu"
-        output = f"{fast_gelu_name}/output_0"
-        self.make_node("FastGelu", inputs=[root_input], outputs=[output], name=fast_gelu_name, domain="com.microsoft")
+        #       GeluAct
+        gelu_name = f"/model/layers.{layer_id}/mlp/act_fn/{activation}"
+        output = f"{gelu_name}/output_0"
+        self.make_node(activation, inputs=[root_input], outputs=[output], name=gelu_name, domain="com.microsoft")
         self.make_value_info(output, self.io_dtype, shape=['batch_size', 'sequence_length', self.intermediate_size])
 
-        return fast_gelu_name
+        return gelu_name
 
     def make_activation(self, layer_id, root_input):
         if self.activation in {"silu", "swish"}:
-            output_name = self.make_activation_with_mul(layer_id, root_input, "Sigmoid")
+            output_name = self.make_activation_with_mul(layer_id, root_input, activation="Sigmoid", domain=None)
         elif self.activation in {"gelu_new", "gelu_fast"}:
-            output_name = self.make_fast_gelu(layer_id, root_input)
+            output_name = self.make_gelu(layer_id, root_input, activation="FastGelu")
         elif self.activation in {"gelu"}:
-            output_name = self.make_activation_with_mul(layer_id, root_input, "Gelu")
+            output_name = self.make_gelu(layer_id, root_input, activation="Gelu")
         else:
             raise NotImplementedError(f"The {self.activation} activation function is not currently supported.")
 
         return output_name
@@ -1298,6 +1328,13 @@ def make_layer(self, layer_id, layer):
 
         self.layernorm_attrs["skip_input"] = f"{residual_add_name}/output_0"
 
+
+class GemmaModel(MistralModel):
+    def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
+        super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
+        self.embed_attrs["normalize"] = True
+        self.layernorm_attrs["add_offset"] = 1
+
 def parse_extra_options(kv_items):
     """
     Parse key value pairs that are separated by '='
@@ -1314,6 +1351,7 @@ def parse_extra_options(kv_items):
 
 
 def create_model(model_name_or_path, output_dir, precision, execution_provider, cache_dir, **extra_options):
+    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(cache_dir, exist_ok=True)
     extra_kwargs = {} if os.path.exists(model_name_or_path) else {"cache_dir": cache_dir, "use_auth_token": True}
     config = AutoConfig.from_pretrained(model_name_or_path, **extra_kwargs)
@@ -1321,7 +1359,10 @@ def create_model(model_name_or_path, output_dir, precision, execution_provider,
 
     # Set input/output precision of ONNX model
     io_dtype = TensorProto.FLOAT if precision in {"int8", "fp32"} or (precision == "int4" and execution_provider == "cpu") else TensorProto.FLOAT16
-    if config.architectures[0] == "LlamaForCausalLM":
+    # List architecture options in alphabetical order
+    if config.architectures[0] == "GemmaForCausalLM":
+        onnx_model = GemmaModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
+    elif config.architectures[0] == "LlamaForCausalLM":
         onnx_model = LlamaModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
     elif config.architectures[0] == "MistralForCausalLM":
         onnx_model = MistralModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
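`GemmaModel` reuses `MistralModel` wholesale and only flips the two new switches: Gemma scales its token embeddings by `sqrt(hidden_size)` (hence `embed_attrs["normalize"]`), and its norm layer multiplies by `(1 + weight)`, which the builder folds into the exported tensor via `add_offset = 1` so a standard SimplifiedLayerNorm kernel can be reused. A rough numerical sketch of what the two flags correspond to; the sizes and epsilon below are assumptions for illustration only.

```python
import torch

hidden_size, eps = 4, 1e-6              # toy size; eps is an assumed value
x = torch.randn(1, 3, hidden_size)      # output of the embedding Gather node

# embed_attrs["normalize"]: scale embeddings by sqrt(hidden_size)
x = x * torch.sqrt(torch.tensor(float(hidden_size)))

# layernorm_attrs["add_offset"] = 1: Gemma's norm applies (1 + weight), so
# exporting (weight + 1) makes a plain SimplifiedLayerNorm kernel match it
weight = torch.zeros(hidden_size)       # Gemma stores weight as an offset from 1
x_norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
out = x_norm * (weight + 1.0)           # equals SimplifiedLayerNorm with offset weight
```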
@@ -1357,7 +1398,7 @@ def get_args():
         "-o",
         "--output",
         required=True,
-        help="Path to folder containing ONNX model and additional files (e.g. GenAI config, external data files, etc.)",
+        help="Path to folder to store ONNX model and additional files (e.g. GenAI config, external data files, etc.)",
     )
 
     parser.add_argument(
@@ -1381,7 +1422,7 @@ def get_args():
         "--cache_dir",
         required=False,
         type=str,
-        default=os.path.join(os.path.dirname(os.path.realpath(__file__)), 'cache_dir'),
+        default=os.path.join('.', 'cache_dir'),
         help="Model cache directory (if providing model name and not folder path)",
     )
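One behavioral note on the final hunk: the default `--cache_dir` moves from a folder next to the installed builder.py (which may sit in a read-only site-packages directory) to a path relative to the current working directory. A small illustration; the printed paths are examples only.

```python
import os

# Old default: anchored to the installed module's location
old_default = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'cache_dir')
# e.g. .../site-packages/onnxruntime_genai/models/cache_dir

# New default: anchored to wherever the tool is invoked from
new_default = os.path.join('.', 'cache_dir')
print(os.path.abspath(new_default))  # e.g. /current/working/directory/cache_dir
```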